`_
+Regular development tasks
+=========================
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command and it is not yet available in the new ``breeze`` command):
-
-.. raw:: html
+The regular Breeze development tasks are available as top-level commands. These are the tasks used most
+often during development, which is why they are available without any sub-command. More advanced
+commands are grouped into sub-commands.
-
+Entering Breeze shell
+---------------------
-Choosing different Breeze environment configuration
-===================================================
+This is the most frequently used feature of Breeze. It simply lets you enter the shell inside the Breeze
+development environment (inside the Breeze container).
You can use additional ``breeze`` flags to choose your environment. You can specify a Python
version to use, and backend (the meta-data database). Thanks to that, with Breeze, you can recreate the same
@@ -398,20 +368,6 @@ default settings.
You can see which value of the parameters that can be stored persistently in cache marked with >VALUE<
in the help of the commands.
-Another part of configuration is enabling/disabling cheatsheet, asciiart. The cheatsheet and asciiart can
-be disabled - they are "nice looking" and cheatsheet
-contains useful information for first time users but eventually you might want to disable both if you
-find it repetitive and annoying.
-
-With the config setting colour-blind-friendly communication for Breeze messages. By default we communicate
-with the users about information/errors/warnings/successes via colour-coded messages, but we can switch
-it off by passing ``--no-colour`` to config in which case the messages to the user printed by Breeze
-will be printed using different schemes (italic/bold/underline) to indicate different kind of messages
-rather than colours.
-
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
-
.. raw:: html
@@ -421,199 +377,149 @@ command but it is very similar to current ``breeze`` command):
-Those are all available flags of ``config`` command:
-
-.. image:: ./images/breeze/output-config.svg
- :width: 100%
- :alt: Breeze config
+Building the documentation
+--------------------------
+To build documentation in Breeze, use the ``build-docs`` command:
-You can also dump hash of the configuration options used - this is mostly use to generate the dump
-of help of the commands only when they change.
+.. code-block:: bash
-.. image:: ./images/breeze/output-command-hash-export.svg
- :width: 100%
- :alt: Breeze command-hash-export
+ breeze build-docs
+Results of the build can be found in the ``docs/_build`` folder.
-Starting complete Airflow installation
-======================================
+The documentation build consists of three steps:
-For testing Airflow oyou often want to start multiple components (in multiple terminals). Breeze has
-built-in ``start-airflow`` command that start breeze container, launches multiple terminals using tmux
-and launches all Airflow necessary components in those terminals.
+* verifying consistency of indexes
+* building documentation
+* spell checking
-You can also use it to start any released version of Airflow from ``PyPI`` with the
-``--use-airflow-version`` flag.
+You can run just one of the latter two stages by providing the ``--spellcheck-only`` or
+``--docs-only`` flag.
.. code-block:: bash
- breeze --python 3.7 --backend mysql --use-airflow-version 2.2.5 start-airflow
+ breeze build-docs --spellcheck-only
-Those are all available flags of ``start-airflow`` command:
+This process can take some time, so to shorten it you can filter by package, using the
+``--package-filter`` flag followed by the package name. The package name has to be one of the providers or ``apache-airflow``. For
+instance, to use it with Amazon, the command would be:
-.. image:: ./images/breeze/output-start-airflow.svg
- :width: 100%
- :alt: Breeze start-airflow
+.. code-block:: bash
+ breeze build-docs --package-filter apache-airflow-providers-amazon
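
When several packages need rebuilding, it can help to print the per-package commands before running them. The following is a minimal sketch, not part of Breeze: ``build_docs_commands`` is a hypothetical helper name, and it only prints ``breeze build-docs --package-filter`` invocations for the package names it is given.

```shell
# Hypothetical helper (not part of Breeze): prints one build-docs command
# per package name passed as an argument, without executing anything.
build_docs_commands() {
    local package
    for package in "$@"; do
        printf 'breeze build-docs --package-filter %s\n' "$package"
    done
}

# Print the commands for the core package and the Amazon provider.
build_docs_commands apache-airflow apache-airflow-providers-amazon
```

Printing the commands first keeps the sketch safe to run anywhere; the output can be reviewed and then pasted into a terminal where Breeze is installed.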
-Troubleshooting
-===============
+Errors during documentation generation often come from the docstrings of auto-api generated classes.
+During the docs build, the auto-api generated files are stored in the ``docs/_api`` folder. This helps you
+easily identify where the documentation problems originated.
-If you are having problems with the Breeze environment, try the steps below. After each step you
-can check whether your problem is fixed.
+Those are all available flags of ``build-docs`` command:
-1. If you are on macOS, check if you have enough disk space for Docker (Breeze will warn you if not).
-2. Stop Breeze with ``breeze stop``.
-3. Delete the ``.build`` directory and run ``breeze build-image``.
-4. Clean up Docker images via ``breeze cleanup`` command.
-5. Restart your Docker Engine and try again.
-6. Restart your machine and try again.
-7. Re-install Docker Desktop and try again.
+.. image:: ./images/breeze/output_build-docs.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_build-docs.svg
+ :width: 100%
+ :alt: Breeze build documentation
-In case the problems are not solved, you can set the VERBOSE_COMMANDS variable to "true":
-.. code-block::
+.. raw:: html
- export VERBOSE_COMMANDS="true"
+
+Running static checks
+---------------------
-Then run the failed command, copy-and-paste the output from your terminal to the
-`Airflow Slack `_ #airflow-breeze channel and
-describe your problem.
+You can run static checks via Breeze. You can also run them via the pre-commit command, but with auto-completion
+Breeze makes it easier to run selected static checks. If you press Tab after ``static-checks`` and
+you have auto-complete set up, you should see an auto-completable list of all available checks.
-Uses of the Airflow Breeze environment
-======================================
+.. code-block:: bash
-Airflow Breeze is a bash script serving as a "swiss-army-knife" of Airflow testing. Under the
-hood it uses other scripts that you can also run manually if you have problem with running the Breeze
-environment. Breeze script allows performing the following tasks:
+ breeze static-checks -t run-mypy
+
+The above will run mypy check for currently staged files.
+
+You can also pass specific pre-commit flags, for example ``--all-files``:
-Development tasks
------------------
+.. code-block:: bash
-Those are commands mostly used by contributors:
+ breeze static-checks -t run-mypy --all-files
-* Execute arbitrary command in the test environment with ``breeze shell`` command
-* Enter interactive shell in CI container when ``shell`` (or no command) is specified
-* Start containerised, development-friendly airflow installation with ``breeze start-airflow`` command
-* Build documentation with ``breeze build-docs`` command
-* Initialize local virtualenv with ``./scripts/tools/initialize_virtualenv.py`` command
-* Run static checks with autocomplete support ``breeze static-checks`` command
-* Run test specified with ``breeze tests`` command
-* Build CI docker image with ``breeze build-image`` command
-* Cleanup breeze with ``breeze cleanup`` command
+The above will run mypy check for all files.
-Additional management tasks:
+There is a convenience ``--last-commit`` flag that you can use to run static checks on the last commit only:
-* Join running interactive shell with ``breeze exec`` command
-* Stop running interactive environment with ``breeze stop`` command
-* Execute arbitrary docker-compose command with ``./breeze-legacy docker-compose`` command
+.. code-block:: bash
-Tests
------
+ breeze static-checks -t run-mypy --last-commit
-* Run docker-compose tests with ``breeze docker-compose-tests`` command.
-* Run test specified with ``breeze tests`` command.
+The above will run mypy check for all files in the last commit.
-.. image:: ./images/breeze/output-tests.svg
- :width: 100%
- :alt: Breeze tests
+There is another convenience ``--commit-ref`` flag that you can use to run static checks on a specific commit:
-Kubernetes tests
-----------------
+.. code-block:: bash
-* Manage KinD Kubernetes cluster and deploy Airflow to KinD cluster ``./breeze-legacy kind-cluster`` commands
-* Run Kubernetes tests specified with ``./breeze-legacy kind-cluster tests`` command
-* Enter the interactive kubernetes test environment with ``./breeze-legacy kind-cluster shell`` command
+ breeze static-checks -t run-mypy --commit-ref 639483d998ecac64d0fef7c5aa4634414065f690
-CI Image tasks
---------------
+The above will run mypy check for all files in the 639483d998ecac64d0fef7c5aa4634414065f690 commit.
+Any ``commit-ish`` reference from Git will work here (branch, tag, short/long hash etc.)
-The image building is usually run for users automatically when needed,
-but sometimes Breeze users might want to manually build, pull or verify the CI images.
+If you ever need to get a list of the files that will be checked (for troubleshooting) use these commands:
-* Build CI docker image with ``breeze build-image`` command
-* Pull CI images in parallel ``breeze pull-image`` command
-* Verify CI image ``breeze verify-image`` command
+.. code-block:: bash
-PROD Image tasks
-----------------
+ breeze static-checks -t identity --verbose # currently staged files
+ breeze static-checks -t identity --verbose --from-ref $(git merge-base main HEAD) --to-ref HEAD # branch updates
-Users can also build Production images when they are developing them. However when you want to
-use the PROD image, the regular docker build commands are recommended. See
-`building the image `_
+Those are all available flags of ``static-checks`` command:
-* Build PROD image with ``breeze build-prod-image`` command
-* Pull PROD image in parallel ``breeze pull-prod-image`` command
-* Verify CI image ``breeze verify-prod-image`` command
+.. image:: ./images/breeze/output_static-checks.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_static-checks.svg
+ :width: 100%
+ :alt: Breeze static checks
-Configuration and maintenance
------------------------------
-* Cleanup breeze with ``breeze cleanup`` command
-* Self-upgrade breeze with ``breeze self-upgrade`` command
-* Setup autocomplete for Breeze with ``breeze setup-autocomplete`` command
-* Checking available resources for docker with ``breeze resource-check`` command
-* Freeing space needed to run CI tests with ``breeze free-space`` command
-* Fixing ownership of files in your repository with ``breeze fix-ownership`` command
-* Print Breeze version with ``breeze version`` command
-* Outputs hash of commands defined by ``breeze`` with ``command-hash-export`` (useful to avoid needless
- regeneration of Breeze images)
-
-Release tasks
--------------
+.. note::
-Maintainers also can use Breeze for other purposes (those are commands that regular contributors likely
-do not need or have no access to run). Those are usually connected with releasing Airflow:
+ When you run static checks, some of the artifacts (mypy_cache) are stored in a docker-compose volume
+ so that static checks execution can be sped up significantly. However, sometimes the cache might
+ get broken, in which case you should run ``breeze stop`` to clean up the cache.
-* Prepare cache for CI: ``breeze build-image --prepare-build-cache`` and
- ``breeze build-prod image --prepare-build-cache``(needs buildx plugin and write access to registry ghcr.io)
-* Generate constraints with ``breeze generate-constraints`` (needed when conflicting changes are merged)
-* Prepare airflow packages: ``breeze prepare-airflow-package`` (when releasing Airflow)
-* Verify providers: ``breeze verify-provider-packages`` (when releasing provider packages) - including importing
- the providers in an earlier airflow version.
-* Prepare provider documentation ``breeze prepare-provider-documentation`` and prepare provider packages
- ``breeze prepare-provider-packages`` (when releasing provider packages)
-* Finding the updated dependencies since the last successful build when we have conflict with
- ``breeze find-newer-dependencies`` command
-* Release production images to DockerHub with ``breeze release-prod-images`` command
+Starting Airflow
+----------------
-Details of Breeze usage
-=======================
+For testing Airflow you often want to start multiple components (in multiple terminals). Breeze has
+a built-in ``start-airflow`` command that starts the Breeze container, launches multiple terminals using tmux
+and launches all necessary Airflow components in those terminals.
-Database volumes in Breeze
---------------------------
+When you start Airflow from local sources, www asset compilation is automatically executed first.
-Breeze keeps data for all it's integration in named docker volumes. Each backend and integration
-keeps data in their own volume. Those volumes are persisted until ``breeze stop`` command.
-You can also preserve the volumes by adding flag ``--preserve-volumes`` when you run the command.
-Then, next time when you start Breeze, it will have the data pre-populated.
+.. code-block:: bash
-Those are all available flags of ``stop`` command:
+ breeze --python 3.7 --backend mysql start-airflow
-.. image:: ./images/breeze/output-stop.svg
- :width: 100%
- :alt: Breeze stop
-Image cleanup
---------------
+You can also use it to start any released version of Airflow from ``PyPI`` with the
+``--use-airflow-version`` flag.
-Breeze uses docker images heavily and those images are rebuild periodically. This might cause extra
-disk usage by the images. If you need to clean-up the images periodically you can run
-``breeze cleanup`` command (by default it will skip removing your images before cleaning up but you
-can also remove the images to clean-up everything by adding ``--all``).
+.. code-block:: bash
-Those are all available flags of ``cleanup`` command:
+ breeze start-airflow --python 3.7 --backend mysql --use-airflow-version 2.2.5
+Those are all available flags of ``start-airflow`` command:
-.. image:: ./images/breeze/output-cleanup.svg
+.. image:: ./images/breeze/output_start-airflow.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_start-airflow.svg
:width: 100%
- :alt: Breeze cleanup
+ :alt: Breeze start-airflow
-Launching multiple terminals
-----------------------------
+Launching multiple terminals in the same environment
+----------------------------------------------------
Often if you want to run full airflow in the Breeze environment you need to launch multiple terminals and
run ``airflow webserver``, ``airflow scheduler``, ``airflow worker`` in separate terminals.
@@ -624,299 +530,696 @@ capability of creating multiple virtual terminals and multiplex between them. Mo
found at `tmux GitHub wiki page `_ . Tmux has several useful shortcuts
that allow you to split the terminals, open new tabs etc - it's pretty useful to learn it.
-Here is the part of Breeze video which is relevant:
-
-.. raw:: html
-
-
-
-
Another way is to exec into Breeze terminal from the host's terminal. Often you can
have multiple terminals in the host (Linux/MacOS/WSL2 on Windows) and you can simply use those terminals
to enter the running container. It's as easy as launching ``breeze exec`` while you already started the
Breeze environment. You will be dropped into bash and environment variables will be read in the same
way as when you enter the environment. You can do it multiple times and open as many terminals as you need.
-Here is the part of Breeze video which is relevant:
-
-.. raw:: html
-
-
-
-
Those are all available flags of ``exec`` command:
-.. image:: ./images/breeze/output-exec.svg
+.. image:: ./images/breeze/output_exec.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_exec.svg
:width: 100%
:alt: Breeze exec
-Additional tools
-----------------
-To shrink the Docker image, not all tools are pre-installed in the Docker image. But we have made sure that there
-is an easy process to install additional tools.
+Compiling www assets
+--------------------
-Additional tools are installed in ``/files/bin``. This path is added to ``$PATH``, so your shell will
-automatically autocomplete files that are in that directory. You can also keep the binaries for your tools
-in this directory if you need to.
+Airflow webserver needs to prepare www assets - compiled with node and yarn. The ``compile-www-assets``
+command takes care of it. This is needed when you want to run the webserver inside Breeze.
-**Installation scripts**
+.. image:: ./images/breeze/output_compile-www-assets.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_compile-www-assets.svg
+ :width: 100%
+ :alt: Breeze compile-www-assets
-For the development convenience, we have also provided installation scripts for commonly used tools. They are
-installed to ``/files/opt/``, so they are preserved after restarting the Breeze environment. Each script
-is also available in ``$PATH``, so just type ``install_`` to get a list of tools.
+Breeze cleanup
+--------------
-Currently available scripts:
+Breeze uses docker images heavily and those images are rebuilt periodically. This might cause extra
+disk usage by the images. If you need to clean up the images periodically you can run the
+``breeze setup cleanup`` command (by default it will skip removing your images before cleaning up but you
+can also remove the images to clean up everything by adding ``--all``).
-* ``install_aws.sh`` - installs `the AWS CLI `__ including
-* ``install_az.sh`` - installs `the Azure CLI `__ including
-* ``install_gcloud.sh`` - installs `the Google Cloud SDK `__ including
- ``gcloud``, ``gsutil``.
-* ``install_imgcat.sh`` - installs `imgcat - Inline Images Protocol `__
- for iTerm2 (Mac OS only)
-* ``install_java.sh`` - installs `the OpenJDK 8u41 `__
-* ``install_kubectl.sh`` - installs `the Kubernetes command-line tool, kubectl `__
-* ``install_snowsql.sh`` - installs `SnowSQL `__
-* ``install_terraform.sh`` - installs `Terraform `__
+Those are all available flags of ``cleanup`` command:
-Launching Breeze integrations
------------------------------
-When Breeze starts, it can start additional integrations. Those are additional docker containers
-that are started in the same docker-compose command. Those are required by some of the tests
-as described in ``_.
+.. image:: ./images/breeze/output_cleanup.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_cleanup.svg
+ :width: 100%
+ :alt: Breeze setup cleanup
-By default Breeze starts only airflow container without any integration enabled. If you selected
-``postgres`` or ``mysql`` backend, the container for the selected backend is also started (but only the one
-that is selected). You can start the additional integrations by passing ``--integration`` flag
-with appropriate integration name when starting Breeze. You can specify several ``--integration`` flags
-to start more than one integration at a time.
-Finally you can specify ``--integration all`` to start all integrations.
+Running arbitrary commands in container
+---------------------------------------
-Once integration is started, it will continue to run until the environment is stopped with
-``breeze stop`` command. or restarted via ``breeze restart`` command
+For more sophisticated use of the Breeze shell, use the ``breeze shell`` command - it has more parameters
+and you can also use it to execute arbitrary commands inside the container.
-Note that running integrations uses significant resources - CPU and memory.
+.. code-block:: bash
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+ breeze shell "ls -la"
-.. raw:: html
+Those are all available flags of ``shell`` command:
-
+.. image:: ./images/breeze/output_shell.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_shell.svg
+ :width: 100%
+ :alt: Breeze shell
-Managing CI images
-------------------
-With Breeze you can build images that are used by Airflow CI and production ones.
+Stopping the environment
+------------------------
-For all development tasks, unit tests, integration tests, and static code checks, we use the
-**CI image** maintained in GitHub Container Registry.
+After starting up, the environment runs in the background and takes quite some memory which you might
+want to free for other things you are running on your host.
-The CI image is built automatically as needed, however it can be rebuilt manually with
-``build-image`` command. The production
-image should be built manually - but also a variant of this image is built automatically when
-kubernetes tests are executed see `Running Kubernetes tests <#running-kubernetes-tests>`_
+You can always stop it via:
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+.. code-block:: bash
-.. raw:: html
+ breeze stop
-
+Those are all available flags of ``stop`` command:
-Building the image first time pulls a pre-built version of images from the Docker Hub, which may take some
-time. But for subsequent source code changes, no wait time is expected.
-However, changes to sensitive files like ``setup.py`` or ``Dockerfile.ci`` will trigger a rebuild
-that may take more time though it is highly optimized to only rebuild what is needed.
+.. image:: ./images/breeze/output_stop.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_stop.svg
+ :width: 100%
+ :alt: Breeze stop
-Breeze has built in mechanism to check if your local image has not diverged too much from the
-latest image build on CI. This might happen when for example latest patches have been released as new
-Python images or when significant changes are made in the Dockerfile. In such cases, Breeze will
-download the latest images before rebuilding because this is usually faster than rebuilding the image.
+Troubleshooting
+===============
-Those are all available flags of ``build-image`` command:
+If you are having problems with the Breeze environment, try the steps below. After each step you
+can check whether your problem is fixed.
-.. image:: ./images/breeze/output-build-image.svg
- :width: 100%
- :alt: Breeze build-image
+1. If you are on macOS, check if you have enough disk space for Docker (Breeze will warn you if not).
+2. Stop Breeze with ``breeze stop``.
+3. Delete the ``.build`` directory and run ``breeze ci-image build``.
+4. Clean up Docker images via ``breeze cleanup`` command.
+5. Restart your Docker Engine and try again.
+6. Restart your machine and try again.
+7. Re-install Docker Desktop and try again.
-You can also pull the CI images locally in parallel with optional verification.
+In case the problems are not solved, you can set the VERBOSE_COMMANDS variable to "true":
-Those are all available flags of ``pull-image`` command:
+.. code-block::
-.. image:: ./images/breeze/output-pull-image.svg
- :width: 100%
- :alt: Breeze pull-image
+ export VERBOSE_COMMANDS="true"
-Finally, you can verify CI image by running tests - either with the pulled/built images or
-with an arbitrary image.
-Those are all available flags of ``verify-image`` command:
+Then run the failed command, copy-and-paste the output from your terminal to the
+`Airflow Slack `_ #airflow-breeze channel and
+describe your problem.
+
+Advanced commands
+=================
+
+Airflow Breeze is a bash script serving as a "swiss-army-knife" of Airflow testing. Under the
+hood it uses other scripts that you can also run manually if you have problems running the Breeze
+environment. The Breeze script lets you perform the following tasks:
+
+Running tests
+-------------
+
+You can run tests with ``breeze``. There are various test types, and Breeze lets you run the different
+types easily. You can run unit tests in different ways: either interactively with the default
+``shell`` command or via the ``testing`` commands. The latter lets you run more kinds of tests easily.
-.. image:: ./images/breeze/output-verify-image.svg
+Here is the detailed set of options for the ``breeze testing`` command.
+
+.. image:: ./images/breeze/output_testing.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing.svg
:width: 100%
- :alt: Breeze verify-image
+ :alt: Breeze testing
-Verifying providers
--------------------
-Breeze can also be used to verify if provider classes are importable and if they are following the
-right naming conventions. This happens automatically on CI but you can also run it manually.
+Iterate on tests interactively via ``shell`` command
+....................................................
-.. code-block:: bash
+You can simply enter the ``breeze`` container and run ``pytest`` command there. You can enter the
+container via just ``breeze`` command or ``breeze shell`` command (the latter has more options
+useful when you run integration or system tests). This is the best way if you want to interactively
+run selected tests and iterate with the tests. Once you enter ``breeze`` environment it is ready
+out-of-the-box to run your tests by running the right ``pytest`` command (autocomplete should help
+you with autocompleting test name if you start typing ``pytest tests``).
- breeze verify-provider-packages
+Here are a few examples:
-You can also run the verification with an earlier airflow version to check for compatibility.
+To run a single test:
.. code-block:: bash
- breeze verify-provider-packages --use-airflow-version 2.1.0
+ pytest tests/core/test_core.py::TestCore::test_check_operators
-All the command parameters are here:
+To run the whole test class:
-.. image:: ./images/breeze/output-verify-provider-packages.svg
- :width: 100%
- :alt: Breeze verify-provider-packages
+.. code-block:: bash
-Preparing packages
-------------------
+ pytest tests/core/test_core.py::TestCore
-Breeze can also be used to prepare airflow packages - both "apache-airflow" main package and
-provider packages.
+You can re-run the tests interactively, add extra parameters to pytest and modify the files before
+re-running the test to iterate over the tests. You can also add more flags when starting the
+``breeze shell`` command when you run integration tests or system tests. Read more details about it
+in ``TESTING.rst``, where all our test types are explained along with more information
+on how to run them.
-You can read more about testing provider packages in
-`TESTING.rst `_
+This applies to all kinds of tests - all our tests can be run using pytest.
-There are several commands that you can run in Breeze to manage and build packages:
+Running unit tests
+..................
-* preparing Provider documentation files
-* preparing Airflow packages
-* preparing Provider packages
+Another option is to run tests via the built-in ``breeze testing tests`` command.
+The iterative ``pytest`` approach lets you run tests individually, by class, or in any other way
+pytest supports, and run them interactively. The ``breeze testing tests`` command instead lets you
+run the tests by the same test "types" that are used to run the tests in CI: for example Core, Always,
+API, Providers. This is how our CI runs them - running each group in parallel with the other groups - and you can
+replicate this behaviour.
-Preparing provider documentation files is part of the release procedure by the release managers
-and it is described in detail in `dev `_ .
+Another interesting use of the ``breeze testing tests`` command is that you can easily specify a subset of the
+tests for Providers.
-The below example perform documentation preparation for provider packages.
+For example this will only run provider tests for airbyte and http providers:
.. code-block:: bash
- breeze prepare-provider-documentation
+ breeze testing tests --test-type "Providers[airbyte,http]"
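
The ``Providers[...]`` value above is a comma-separated list of provider names inside square brackets. As a minimal sketch, assuming that is the only format needed, a small shell function can assemble the value from a list of names. ``providers_test_type`` is a hypothetical helper name, not part of Breeze.

```shell
# Hypothetical helper (not part of Breeze): joins its arguments with commas
# and wraps them in the Providers[...] form shown above.
providers_test_type() {
    local IFS=,
    # "$*" joins all arguments using the first character of IFS (a comma).
    printf 'Providers[%s]\n' "$*"
}

# Prints the value that could be passed to --test-type.
providers_test_type airbyte http
```

The result could then be passed on, for instance as ``breeze testing tests --test-type "$(providers_test_type airbyte http)"``, which keeps the provider list in one place when it is reused across several runs.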
-By default, the documentation preparation runs package verification to check if all packages are
-importable, but you can add ``--skip-package-verification`` to skip it.
+You can also run tests in parallel with the ``--run-in-parallel`` flag. By default it runs all test types
+in parallel, but you can choose the test types to run by passing a space-separated list of test
+types to the ``--test-types`` flag.
+
+For example this will run API and WWW tests in parallel:
.. code-block:: bash
- breeze prepare-provider-documentation --skip-package-verification
+ breeze testing tests --test-types "API WWW" --run-in-parallel
-You can also add ``--answer yes`` to perform non-interactive build.
-.. image:: ./images/breeze/output-prepare-provider-documentation.svg
+Here is the detailed set of options for the ``breeze testing tests`` command.
+
+.. image:: ./images/breeze/output_testing_tests.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_tests.svg
:width: 100%
- :alt: Breeze prepare-provider-documentation
+ :alt: Breeze testing tests
-The packages are prepared in ``dist`` folder. Note, that this command cleans up the ``dist`` folder
-before running, so you should run it before generating airflow package below as it will be removed.
+Running integration tests
+.........................
-The below example builds provider packages in the wheel format.
+You can also run integration tests via the built-in ``breeze testing integration-tests`` command. Some of our
+tests require additional integrations to be started in docker-compose. The integration tests command will
+start the specified integration and run the tests that need it.
+
+For example this will only run kerberos tests:
.. code-block:: bash
- breeze prepare-provider-packages
+ breeze testing integration-tests --integration kerberos
-If you run this command without packages, you will prepare all packages, you can however specify
-providers that you would like to build. By default ``both`` types of packages are prepared (
-``wheel`` and ``sdist``, but you can change it providing optional --package-format flag.
-.. code-block:: bash
+Here is the detailed set of options for the ``breeze testing integration-tests`` command.
- breeze prepare-provider-packages google amazon
+.. image:: ./images/breeze/output_testing_integration-tests.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_integration-tests.svg
+ :width: 100%
+ :alt: Breeze testing integration-tests
-You can see all providers available by running this command:
-.. code-block:: bash
+Running Helm tests
+..................
- breeze prepare-provider-packages --help
+You can use Breeze to run all Helm tests. Those tests are run inside the Breeze image as all the
+necessary tools are installed there.
-.. image:: ./images/breeze/output-prepare-provider-packages.svg
+.. image:: ./images/breeze/output_testing_helm-tests.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_helm-tests.svg
:width: 100%
- :alt: Breeze prepare-provider-packages
+ :alt: Breeze testing helm-tests
-You can prepare airflow packages using breeze:
+You can also iterate over those tests with pytest commands, as with regular unit tests.
+The Helm tests can be found in the ``tests/chart`` folder in the main repo.
-.. code-block:: bash
+Running docker-compose tests
+............................
- breeze prepare-airflow-package
+You can use Breeze to run all docker-compose tests. Those tests are run using the Production image
+and they exercise the Quick-start docker-compose setup we have.
-This prepares airflow .whl package in the dist folder.
+.. image:: ./images/breeze/output_testing_docker-compose-tests.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_testing_docker-compose-tests.svg
+ :width: 100%
+ :alt: Breeze testing docker-compose-tests
-Again, you can specify optional ``--package-format`` flag to build selected formats of airflow packages,
-default is to build ``both`` type of packages ``sdist`` and ``wheel``.
+You can also iterate over those tests with the pytest command but, unlike regular unit tests and
+Helm tests, they need to be run in a local virtual environment. If you do not run them through the
+``breeze testing docker-compose-tests`` command, they also require the ``DOCKER_IMAGE`` environment
+variable to be set, pointing to the image to test.
-.. code-block:: bash
+The docker-compose tests are in ``docker-tests/`` folder in the main repo.
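+
+A minimal sketch of such a local run, assuming an activated local virtual environment that already has
+the test dependencies installed. The image name is a hypothetical placeholder rather than a value taken
+from this document, and the plain ``pytest`` call simply runs whatever is found in that folder:
+
+.. code-block:: bash
+
+    # Hypothetical image name - point it at the image you actually want to test
+    export DOCKER_IMAGE="my-airflow-prod-image:latest"
+    # Run the tests from the docker-tests/ folder in the local virtual environment
+    pytest docker-tests/
+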
- breeze prepare-airflow-package --package-format=wheel
+Running Kubernetes tests
+------------------------
-.. image:: ./images/breeze/output-prepare-airflow-package.svg
+Breeze helps with running Kubernetes tests in the same environment and in the same way as CI tests are run.
+Breeze sets up the KinD cluster for testing, creates the virtualenv and downloads the right tools
+automatically to run the tests.
+
+You can:
+
+* Set up the environment for k8s tests with ``breeze k8s setup-env``
+* Build airflow k8s images with ``breeze k8s build-k8s-image``
+* Manage the KinD Kubernetes cluster, upload the image and deploy Airflow to the KinD cluster via the
+  ``breeze k8s create-cluster``, ``breeze k8s configure-cluster``, ``breeze k8s deploy-airflow``, ``breeze k8s status``,
+  ``breeze k8s upload-k8s-image``, ``breeze k8s delete-cluster`` commands
+* Run the specified Kubernetes tests with the ``breeze k8s tests`` command
+* Run a complete test run with ``breeze k8s run-complete-tests`` - performing the full cycle of creating
+  the cluster, uploading the image, deploying airflow, running tests and deleting the cluster
+* Enter the interactive kubernetes test environment with the ``breeze k8s shell`` and ``breeze k8s k9s`` commands
+* Run multi-cluster operations with the ``breeze k8s list-all-clusters`` and
+  ``breeze k8s delete-all-clusters`` commands as well as running complete tests in parallel
+  via ``breeze k8s dump-logs`` command
+
+This is described in detail in `Testing Kubernetes `_.
+
+You can read more about KinD that we use in `The documentation `_
+
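+As a rough sketch, the manual equivalent of a complete run chains the commands listed above one after
+another. The ordering is an assumption based on the cycle described for ``breeze k8s run-complete-tests``,
+and every command is shown without options:
+
+.. code-block:: bash
+
+    # Prepare the virtualenv and tools (kind, kubectl, helm)
+    breeze k8s setup-env
+    # Create and configure the KinD cluster
+    breeze k8s create-cluster
+    breeze k8s configure-cluster
+    # Build the image and make it available inside the cluster
+    breeze k8s build-k8s-image
+    breeze k8s upload-k8s-image
+    # Deploy Airflow, run the tests and clean up
+    breeze k8s deploy-airflow
+    breeze k8s tests
+    breeze k8s delete-cluster
+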
+Here is the detailed set of options for the ``breeze k8s`` command.
+
+.. image:: ./images/breeze/output_k8s.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s.svg
:width: 100%
- :alt: Breeze prepare-airflow-package
+ :alt: Breeze k8s
-Managing Production images
---------------------------
+
+Setting up K8S environment
+..........................
+
+The Kubernetes environment can be set up with the ``breeze k8s setup-env`` command.
+It creates the appropriate virtualenv to run tests and downloads the right versions of the tools
+needed to run them: ``kind``, ``kubectl`` and ``helm``. You can re-run the command
+when you want to make sure the expected versions of the tools are installed properly in the
+virtualenv. The virtualenv is available in the ``.build/.k8s-env/bin`` subdirectory of your Airflow
+installation.
+
+.. image:: ./images/breeze/output_k8s_setup-env.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_setup-env.svg
+ :width: 100%
+ :alt: Breeze k8s setup-env
+
+Creating K8S cluster
+....................
+
+You can create a kubernetes cluster (a separate cluster for each python/kubernetes version) via the
+``breeze k8s create-cluster`` command. With the ``--force`` flag, the cluster will be
+deleted first if it already exists. You can also use it to create multiple clusters in parallel with the
+``--run-in-parallel`` flag - this is what happens in our CI.
+
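+For illustration, the two flags mentioned above could be used as follows. This is only a sketch: it
+relies on default values for all other options, which are not shown here:
+
+.. code-block:: bash
+
+    # Recreate the cluster, deleting an existing one if present
+    breeze k8s create-cluster --force
+    # Create the clusters in parallel, as done in CI
+    breeze k8s create-cluster --run-in-parallel
+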
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_create-cluster.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_create-cluster.svg
+ :width: 100%
+ :alt: Breeze k8s create-cluster
+
+Deleting K8S cluster
+....................
+
+You can delete the current kubernetes cluster via the ``breeze k8s delete-cluster`` command. You can also add the
+``--run-in-parallel`` flag to delete all clusters.
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_delete-cluster.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_delete-cluster.svg
+ :width: 100%
+ :alt: Breeze k8s delete-cluster
+
+Building Airflow K8s images
+...........................
+
+Before deploying the Airflow Helm Chart, you need to make sure the appropriate Airflow image is built (it has
+embedded test dags and pod templates, and the webserver is configured to refresh immediately). This can
+be done via the ``breeze k8s build-k8s-image`` command. It can also be done in parallel for all images via the
+``--run-in-parallel`` flag.
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_build-k8s-image.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_build-k8s-image.svg
+ :width: 100%
+ :alt: Breeze k8s build-k8s-image
+
+Uploading Airflow K8s images
+............................
+
+The K8S airflow images need to be uploaded to the KinD cluster. This can be done via
+``breeze k8s upload-k8s-image`` command. It can also be done in parallel for all images via
+``--run-in-parallel`` flag.
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_upload-k8s-image.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_upload-k8s-image.svg
+ :width: 100%
+ :alt: Breeze k8s upload-k8s-image
+
+Configuring K8S cluster
+.......................
+
+In order to deploy Airflow, the cluster needs to be configured: the Airflow namespace needs to be created
+and test resources deployed. This is done with the ``breeze k8s configure-cluster`` command. By passing
+``--run-in-parallel`` the configuration can be run for all clusters in parallel.
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_configure-cluster.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_configure-cluster.svg
+ :width: 100%
+ :alt: Breeze k8s configure-cluster
+
+Deploying Airflow to the Cluster
+................................
+
+Airflow can be deployed to the cluster with ``breeze k8s deploy-airflow``. This step automatically
+rebuilds the image to be deployed (unless disabled by switches). It also uses the latest version
+of the Airflow Helm Chart to deploy it. You can also choose to upgrade an existing airflow deployment
+and pass extra arguments to the ``helm install`` or ``helm upgrade`` commands that are used to
+deploy airflow. By passing ``--run-in-parallel`` the deployment can be run
+for all clusters in parallel.
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_deploy-airflow.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_deploy-airflow.svg
+ :width: 100%
+ :alt: Breeze k8s deploy-airflow
+
+Checking status of the K8S cluster
+..................................
+
+You can check the status of the kubernetes cluster and of airflow deployed in the current cluster
+via the ``breeze k8s status`` command. The status can also be checked for all clusters created so far by passing the
+``--all`` flag.
+
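+A short sketch of both variants, assuming at least one cluster has already been created:
+
+.. code-block:: bash
+
+    # Current cluster only
+    breeze k8s status
+    # All clusters created so far
+    breeze k8s status --all
+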
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_status.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_status.svg
+ :width: 100%
+ :alt: Breeze k8s status
+
+Running k8s tests
+.................
+
+You can run the ``breeze k8s tests`` command to run ``pytest`` tests with your cluster. Those tests are placed
+in ``kubernetes_tests/`` and you can either specify the tests to run as parameters of the tests command or
+leave them empty to run all tests. By passing ``--run-in-parallel`` the tests can be run
+for all clusters in parallel.
+
+Run all tests:
+
+.. code-block:: bash
+
+ breeze k8s tests
+
+Run selected tests:
+
+.. code-block:: bash
+
+ breeze k8s tests kubernetes_tests/test_kubernetes_executor.py
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_tests.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_tests.svg
+ :width: 100%
+ :alt: Breeze k8s tests
+
+You can also specify any pytest flags as extra parameters - they will be passed to the
+pytest command directly. In case the pytest parameters are the same as the parameters of the command, you
+can pass them after ``--``. For example, this is how you can see all available parameters of the pytest
+you have:
+
+.. code-block:: bash
+
+ breeze k8s tests -- --help
+
+Options that do not overlap with the ``tests`` command options can be passed directly and mixed
+with the specifications of tests you want to run. For example, the command below will only run
+``test_kubernetes_executor.py`` and will suppress capturing output from pytest so that you can see the
+output during test execution.
+
+.. code-block:: bash
+
+ breeze k8s tests -- kubernetes_tests/test_kubernetes_executor.py -s
+
+Running k8s complete tests
+..........................
+
+You can run the ``breeze k8s run-complete-tests`` command to combine all previous steps in one command. That
+command will create the cluster, deploy airflow, run the tests and finally delete the cluster. It is used in CI
+to run the whole chains in parallel.
+
+Run all tests:
+
+.. code-block:: bash
+
+ breeze k8s run-complete-tests
+
+Run selected tests:
+
+.. code-block:: bash
+
+ breeze k8s run-complete-tests kubernetes_tests/test_kubernetes_executor.py
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_run-complete-tests.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_run-complete-tests.svg
+ :width: 100%
+ :alt: Breeze k8s run-complete-tests
+
+You can also specify any pytest flags as extra parameters - they will be passed to the
+pytest command directly. In case the pytest parameters are the same as the parameters of the command, you
+can pass them after ``--``. For example, this is how you can see all available parameters of the pytest
+you have:
+
+.. code-block:: bash
+
+ breeze k8s run-complete-tests -- --help
+
+Options that do not overlap with the ``run-complete-tests`` command options can be passed directly and mixed
+with the specifications of tests you want to run. For example, the command below will only run
+``test_kubernetes_executor.py`` and will suppress capturing output from pytest so that you can see the
+output during test execution.
+
+.. code-block:: bash
+
+ breeze k8s run-complete-tests -- kubernetes_tests/test_kubernetes_executor.py -s
+
+
+Entering k8s shell
+..................
+
+You can have multiple clusters created at the same time - with different versions of Kubernetes and Python.
+Breeze enables you to interact with the chosen cluster by entering a dedicated shell session that has the
+cluster pre-configured. This is done via the ``breeze k8s shell`` command.
+
+Once you are in the shell, the prompt will indicate which cluster you are interacting with as well
+as the executor you use, similar to:
+
+.. code-block:: bash
+
+ (kind-airflow-python-3.9-v1.24.0:KubernetesExecutor)>
+
+
+The shell automatically activates the virtual environment that has all appropriate dependencies
+installed and you can interactively run all k8s tests with the pytest command (of course the cluster needs to
+be created and airflow deployed to it before running the tests):
+
+.. code-block:: bash
+
+ (kind-airflow-python-3.9-v1.24.0:KubernetesExecutor)> pytest kubernetes_tests/test_kubernetes_executor.py
+ ================================================= test session starts =================================================
+ platform linux -- Python 3.10.6, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /home/jarek/code/airflow/.build/.k8s-env/bin/python
+ cachedir: .pytest_cache
+ rootdir: /home/jarek/code/airflow, configfile: pytest.ini
+ plugins: anyio-3.6.1
+ collected 2 items
+
+ kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag PASSED [ 50%]
+ kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag_with_scheduler_failure PASSED [100%]
+
+ ================================================== warnings summary ===================================================
+ .build/.k8s-env/lib/python3.10/site-packages/_pytest/config/__init__.py:1233
+ /home/jarek/code/airflow/.build/.k8s-env/lib/python3.10/site-packages/_pytest/config/__init__.py:1233: PytestConfigWarning: Unknown config option: asyncio_mode
+
+ self._warn_or_fail_if_strict(f"Unknown config option: {key}\n")
+
+ -- Docs: https://docs.pytest.org/en/stable/warnings.html
+ ============================================ 2 passed, 1 warning in 38.62s ============================================
+ (kind-airflow-python-3.9-v1.24.0:KubernetesExecutor)>
+
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_shell.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_shell.svg
+ :width: 100%
+ :alt: Breeze k8s shell
+
+You can also specify any shell flags and commands as extra parameters - they will be passed to the
+shell command directly. In case the shell parameters are the same as the parameters of the command, you
+can pass them after ``--``. For example, this is how you can see all available parameters of the shell
+you have:
+
+.. code-block:: bash
+
+ breeze k8s shell -- --help
+
+Running k9s tool
+................
+
+``k9s`` is a fantastic tool that allows you to interact with a running k8s cluster. Since we can have
+multiple clusters, ``breeze k8s k9s`` allows you to start k9s without setting it up or
+downloading it - it uses the k9s docker image to run it and connect it to the right cluster.
+
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_k9s.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_k9s.svg
+ :width: 100%
+ :alt: Breeze k8s k9s
+
+You can also specify any ``k9s`` flags and commands as extra parameters - they will be passed to the
+``k9s`` command directly. In case the ``k9s`` parameters are the same as the parameters of the command, you
+can pass them after ``--``. For example, this is how you can see all available parameters of the
+``k9s`` you have:
+
+.. code-block:: bash
+
+ breeze k8s k9s -- --help
+
+Dumping logs from all k8s clusters
+..................................
+
+KinD allows you to export logs from the running cluster so that you can troubleshoot your deployment.
+This can be done with the ``breeze k8s logs`` command. Logs can also be dumped for all clusters created
+so far by passing the ``--all`` flag.
+
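+A short sketch of both variants, assuming at least one cluster is running:
+
+.. code-block:: bash
+
+    # Dump logs from the current cluster
+    breeze k8s logs
+    # Dump logs from all clusters created so far
+    breeze k8s logs --all
+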
+All parameters of the command are here:
+
+.. image:: ./images/breeze/output_k8s_logs.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_k8s_logs.svg
+ :width: 100%
+ :alt: Breeze k8s logs
+
+
+CI Image tasks
+--------------
+
+The image building is usually run for users automatically when needed,
+but sometimes Breeze users might want to manually build, pull or verify the CI images.
+
+.. image:: ./images/breeze/output_ci-image.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image.svg
+ :width: 100%
+ :alt: Breeze ci-image
+
+For all development tasks, unit tests, integration tests, and static code checks, we use the
+**CI image** maintained in GitHub Container Registry.
+
+The CI image is built automatically as needed, however it can be rebuilt manually with the
+``breeze ci-image build`` command.
+
+Building the image for the first time pulls a pre-built version of images from the Docker Hub, which may take some
+time. But for subsequent source code changes, no wait time is expected.
+However, changes to sensitive files like ``setup.py`` or ``Dockerfile.ci`` will trigger a rebuild
+that may take more time, though it is highly optimized to only rebuild what is needed.
+
+Breeze has a built-in mechanism to check whether your local image has diverged too much from the
+latest image built on CI. This might happen, for example, when the latest patches have been released as new
+Python images or when significant changes are made in the Dockerfile. In such cases, Breeze will
+download the latest images before rebuilding because this is usually faster than rebuilding the image.
+
+Building CI image
+.................
+
+Those are all available flags of ``ci-image build`` command:
+
+.. image:: ./images/breeze/output_ci-image_build.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image_build.svg
+ :width: 100%
+ :alt: Breeze ci-image build
+
+Pulling CI image
+................
+
+You can also pull the CI images locally in parallel with optional verification.
+
+Those are all available flags of ``pull`` command:
+
+.. image:: ./images/breeze/output_ci-image_pull.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image_pull.svg
+ :width: 100%
+ :alt: Breeze ci-image pull
+
+Verifying CI image
+..................
+
+Finally, you can verify CI image by running tests - either with the pulled/built images or
+with an arbitrary image.
+
+Those are all available flags of ``verify`` command:
+
+.. image:: ./images/breeze/output_ci-image_verify.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci-image_verify.svg
+ :width: 100%
+ :alt: Breeze ci-image verify
+
+PROD Image tasks
+----------------
+
+Users can also build Production images when they are developing them. However when you want to
+use the PROD image, the regular docker build commands are recommended. See
+`building the image `_
+
+.. image:: ./images/breeze/output_prod-image.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image.svg
+ :width: 100%
+ :alt: Breeze prod-image
The **Production image** is also maintained in GitHub Container Registry for Caching
and in ``apache/airflow`` manually pushed for released versions. This Docker image (built using official
Dockerfile) contains size-optimised Airflow installation with selected extras and dependencies.
However in many cases you want to add your own custom version of the image - with added apt dependencies,
-python dependencies, additional Airflow extras. Breeze's ``build-image`` command helps to build your own,
+python dependencies, additional Airflow extras. Breeze's ``prod-image build`` command helps to build your own,
customized variant of the image that contains everything you need.
-You can switch to building the production image by using ``build-prod-image`` command.
+You can build the production image manually by using the ``prod-image build`` command.
Note, that the images can also be built using ``docker build`` command by passing appropriate
build-args as described in `IMAGES.rst `_ , but Breeze provides several flags that
-makes it easier to do it. You can see all the flags by running ``breeze build-prod-image --help``,
+makes it easier to do it. You can see all the flags by running ``breeze prod-image build --help``,
but here typical examples are presented:
.. code-block:: bash
- breeze build-prod-image --additional-extras "jira"
+ breeze prod-image build --additional-extras "jira"
This installs additional ``jira`` extra while installing airflow in the image.
.. code-block:: bash
- breeze build-prod-image --additional-python-deps "torchio==0.17.10"
+ breeze prod-image build --additional-python-deps "torchio==0.17.10"
This install additional pypi dependency - torchio in specified version.
-
.. code-block:: bash
- breeze build-prod-image --additional-dev-apt-deps "libasound2-dev" \
+ breeze prod-image build --additional-dev-apt-deps "libasound2-dev" \
--additional-runtime-apt-deps "libasound2"
This installs additional apt dependencies - ``libasound2-dev`` in the build image and ``libasound`` in the
@@ -929,202 +1232,389 @@ suffix and they need to also be paired with corresponding runtime dependency add
.. code-block:: bash
- breeze build-prod-image --python 3.7 --additional-dev-deps "libasound2-dev" \
+ breeze prod-image build --python 3.7 --additional-dev-deps "libasound2-dev" \
--additional-runtime-apt-deps "libasound2"
Same as above but uses python 3.7.
+Building PROD image
+...................
+
Those are all available flags of ``build-prod-image`` command:
-.. image:: ./images/breeze/output-build-prod-image.svg
+.. image:: ./images/breeze/output_prod-image_build.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image_build.svg
:width: 100%
- :alt: Breeze commands
-
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+ :alt: Breeze prod-image build
-.. raw:: html
-
-
+Pulling PROD image
+..................
You can also pull PROD images in parallel with optional verification.
Those are all available flags of ``pull-prod-image`` command:
-.. image:: ./images/breeze/output-pull-prod-image.svg
+.. image:: ./images/breeze/output_prod-image_pull.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image_pull.svg
:width: 100%
- :alt: Breeze pull-prod-image
+ :alt: Breeze prod-image pull
+
+Verifying PROD image
+....................
Finally, you can verify PROD image by running tests - either with the pulled/built images or
with an arbitrary image.
Those are all available flags of ``verify-prod-image`` command:
-.. image:: ./images/breeze/output-verify-prod-image.svg
+.. image:: ./images/breeze/output_prod-image_verify.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_prod-image_verify.svg
:width: 100%
- :alt: Breeze verify-prod-image
+ :alt: Breeze prod-image verify
-Releasing Production images to DockerHub
-----------------------------------------
-The **Production image** can be released by release managers who have permissions to push the image. This
-happens only when there is an RC candidate or final version of Airflow released.
+Breeze setup
+------------
-You release "regular" and "slim" images as separate steps.
+Breeze has tools that you can use to configure defaults and Breeze behaviours and to perform some maintenance
+operations that might be necessary when you add new commands in Breeze. It also allows you to configure your
+host operating system for Breeze autocompletion.
-Releasing "regular" images:
+Those are all available flags of ``setup`` command:
-.. code-block:: bash
+.. image:: ./images/breeze/output_setup.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup.svg
+ :width: 100%
+ :alt: Breeze setup
- breeze release-prod-images --airflow-version 2.4.0
+Breeze configuration
+....................
-Or "slim" images:
+You can configure and inspect settings of Breeze command via this command: Python version, Backend used as
+well as backend versions.
+
+Another part of the configuration is enabling/disabling the cheatsheet and asciiart. They are
+"nice looking" and the cheatsheet contains useful information for first-time users, but eventually
+you might want to disable both if you find them repetitive and annoying.
+
+The config command also lets you enable colour-blind-friendly communication for Breeze messages. By default we communicate
+with the users about information/errors/warnings/successes via colour-coded messages, but this can be switched
+off by passing ``--no-colour`` to config, in which case the messages printed by Breeze
+use different styles (italic/bold/underline) rather than colours to indicate the different kinds
+of messages.
+
+Those are all available flags of ``setup config`` command:
+
+.. image:: ./images/breeze/output_setup_config.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_config.svg
+ :width: 100%
+ :alt: Breeze setup config
+
+Setting up autocompletion
+.........................
+
+You get the auto-completion working when you re-enter the shell (follow the instructions printed).
+The command will warn you and not reinstall autocomplete if you already did, but you can
+also force reinstalling the autocomplete via:
.. code-block:: bash
- breeze release-prod-images --airflow-version 2.4.0 --slim-images
+ breeze setup autocomplete --force
-By default when you are releasing the "final" image, we also tag image with "latest" tags but this
-step can be skipped if you pass the ``--skip-latest`` flag.
+Those are all available flags of ``setup autocomplete`` command:
-These are all of the available flags for the ``release-prod-images`` command:
+.. image:: ./images/breeze/output_setup_autocomplete.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_autocomplete.svg
+ :width: 100%
+ :alt: Breeze setup autocomplete
+
+Breeze version
+..............
-.. image:: ./images/breeze/output-release-prod-images.svg
+You can display the Breeze version, and with the ``--verbose`` flag it can provide more information: where
+Breeze is installed from and details about setup hashes.
+
+Those are all available flags of ``setup version`` command:
+
+.. image:: ./images/breeze/output_setup_version.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_version.svg
:width: 100%
- :alt: Release prod images
+ :alt: Breeze version
-Running static checks
----------------------
+Breeze self-upgrade
+...................
-You can run static checks via Breeze. You can also run them via pre-commit command but with auto-completion
-Breeze makes it easier to run selective static checks. If you press after the static-check and if
-you have auto-complete setup you should see auto-completable list of all checks available.
+You can self-upgrade breeze automatically. Those are all available flags of ``self-upgrade`` command:
-.. code-block:: bash
+.. image:: ./images/breeze/output_setup_self-upgrade.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_self-upgrade.svg
+ :width: 100%
+ :alt: Breeze setup self-upgrade
- breeze static-checks -t mypy
-The above will run mypy check for currently staged files.
+Regenerating images for documentation
+.....................................
-You can also pass specific pre-commit flags for example ``--all-files`` :
+This documentation contains exported images with "help" of their commands and parameters. You can
+regenerate those images that need to be regenerated because their commands changed (usually after
+the breeze code has been changed) via ``regenerate-command-images`` command. Usually this is done
+automatically via pre-commit, but sometimes (for example when ``rich`` or ``rich-click`` library changes)
+you need to regenerate those images.
-.. code-block:: bash
+You can add the ``--force`` flag (or the ``FORCE="true"`` environment variable) to regenerate all images (not
+only those that need regeneration). You can also run the command with the ``--check-only`` flag to simply
+check if there are any images that need regeneration.
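+
+As a sketch, assuming the command is invoked from the Airflow source checkout, the two flags could be
+used like this:
+
+.. code-block:: bash
+
+    # Only report whether any images need regeneration
+    breeze setup regenerate-command-images --check-only
+    # Regenerate all images, not only the outdated ones
+    breeze setup regenerate-command-images --force
+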
- breeze static-checks -t mypy --all-files
+.. image:: ./images/breeze/output_setup_regenerate-command-images.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_setup_regenerate-command-images.svg
+ :width: 100%
+ :alt: Breeze setup regenerate-command-images
-The above will run mypy check for all files.
-There is a convenience ``--last-commit`` flag that you can use to run static check on last commit only:
+CI tasks
+--------
-.. code-block:: bash
+Breeze has a number of commands that are mostly used in the CI environment to perform cleanup.
- breeze static-checks -t mypy --last-commit
+.. image:: ./images/breeze/output_ci.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci.svg
+ :width: 100%
+ :alt: Breeze ci commands
-The above will run mypy check for all files in the last commit.
+Running resource check
+......................
-There is another convenience ``--commit-ref`` flag that you can use to run static check on specific commit:
+Breeze requires certain resources to be available - disk, memory, CPU. When you enter Breeze's shell,
+the resources are checked and information on whether there are enough resources is displayed. However, you can
+manually run the resource check at any time with the ``breeze ci resource-check`` command.
+
+Those are all available flags of ``resource-check`` command:
+
+.. image:: ./images/breeze/output_ci_resource-check.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_resource-check.svg
+ :width: 100%
+ :alt: Breeze ci resource-check
+
+Freeing the space
+.................
+
+When our CI runs a job, it needs all the memory and disk space it can have. We have a Breeze command that frees
+the memory and disk space used. You can also use it to clear space locally, but it performs a few operations
+that might be a bit invasive - such as removing the swap file and completely pruning the docker disk space used.
+
+Those are all available flags of ``free-space`` command:
+
+.. image:: ./images/breeze/output_ci_free-space.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_free-space.svg
+ :width: 100%
+ :alt: Breeze ci free-space
+
+Fixing File/Directory Ownership
+...............................
+
+On Linux, there is a problem with propagating ownership of created files (a known Docker problem). The
+files and directories created in the container are not owned by the host user (but by the root user in our
+case). This may prevent you from switching branches, for example, if files owned by the root user are
+created within your sources. In case you are on a Linux host and have some files in your sources created
+by the root user, you can fix the ownership of those files by running:
+
+.. code-block::
+
+ breeze ci fix-ownership
+
+Those are all available flags of ``fix-ownership`` command:
+
+.. image:: ./images/breeze/output_ci_fix-ownership.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_fix-ownership.svg
+ :width: 100%
+ :alt: Breeze ci fix-ownership
+
+Selective check
+...............
+
+When our CI runs a job, it needs to decide which tests to run, whether to build images and how many
+combinations of Python, Kubernetes and Backend versions the tests should be run on, in order to optimize the
+time needed to run the CI builds. You can also use the tool to check which tests will be run when you provide
+a specific commit that Breeze should run the tests on.
+
+The selective-check command will produce the set of ``name=value`` pairs of outputs derived
+from the context of the commit/PR to be merged via stderr output.
+
+More details about the algorithm used to pick the right tests and the available outputs can be
+found in `Selective Checks `_.
+
+Those are all available flags of ``selective-check`` command:
+
+.. image:: ./images/breeze/output_ci_selective-check.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_selective-check.svg
+ :width: 100%
+ :alt: Breeze ci selective-check
+
+Getting workflow information
+............................
+
+When our CI runs a job, it might be within one of several workflows. Information about those workflows
+is stored in GITHUB_CONTEXT. Rather than using some jq/bash commands, we retrieve the necessary information
+(like PR labels, event_type, where the job runs on, job description) and convert it into GA outputs.
+
+Those are all available flags of ``get-workflow-info`` command:
+
+.. image:: ./images/breeze/output_ci_get-workflow-info.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_get-workflow-info.svg
+ :width: 100%
+ :alt: Breeze ci get-workflow-info
+
+Tracking backtracking issues for CI builds
+..........................................
+
+When our CI runs a job, we automatically upgrade our dependencies in the ``main`` build. However, this might
+lead to conflicts and ``pip`` backtracking for a long time (possibly forever) for dependency resolution.
+Unfortunately those issues are difficult to diagnose so we had to invent our own tool to help us with
+diagnosing them. This tool is ``find-newer-dependencies``, and it helps to guess
+which new dependency might have caused the backtracking. The whole process is described in
+`tracking backtracking issues `_.
+
+Those are all available flags of ``find-newer-dependencies`` command:
+
+.. image:: ./images/breeze/output_ci_find-newer-dependencies.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_ci_find-newer-dependencies.svg
+ :width: 100%
+ :alt: Breeze ci find-newer-dependencies
+
+Release management tasks
+------------------------
+
+Maintainers can also use Breeze for other purposes (those are commands that regular contributors likely
+do not need or do not have access to run). Those are usually connected with releasing Airflow:
+
+.. image:: ./images/breeze/output_release-management.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management.svg
+ :width: 100%
+ :alt: Breeze release management
+
+Breeze can be used to prepare airflow packages - both "apache-airflow" main package and
+provider packages.
+
+Preparing provider documentation
+................................
+
+You can read more about testing provider packages in
+`TESTING.rst `_
+
+There are several commands that you can run in Breeze to manage and build packages:
+
+* preparing Provider documentation files
+* preparing Airflow packages
+* preparing Provider packages
+
+Preparing provider documentation files is part of the release procedure by the release managers
+and it is described in detail in `dev `_ .
+
+The example below performs documentation preparation for provider packages.
.. code-block:: bash
- breeze static-checks -t mypy --commit-ref 639483d998ecac64d0fef7c5aa4634414065f690
+ breeze release-management prepare-provider-documentation
-The above will run mypy check for all files in the 639483d998ecac64d0fef7c5aa4634414065f690 commit.
-Any ``commit-ish`` reference from Git will work here (branch, tag, short/long hash etc.)
+By default, the documentation preparation runs package verification to check if all packages are
+importable, but you can add ``--skip-package-verification`` to skip it.
+
+.. code-block:: bash
+
+ breeze release-management prepare-provider-documentation --skip-package-verification
+
+You can also add ``--answer yes`` to perform a non-interactive build.
+
+.. image:: ./images/breeze/output_release-management_prepare-provider-documentation.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_prepare-provider-documentation.svg
+ :width: 100%
+ :alt: Breeze prepare-provider-documentation
+
+Preparing provider packages
+...........................
+
+You can use Breeze to prepare provider packages.
+
+The packages are prepared in the ``dist`` folder. Note that this command cleans up the ``dist`` folder
+before running, so you should run it before generating the airflow package described below. Otherwise the airflow package will be removed.
-If you ever need to get a list of the files that will be checked (for troubleshooting) use these commands:
+The below example builds provider packages in the wheel format.
.. code-block:: bash
- breeze static-checks -t identity --verbose # currently staged files
- breeze static-checks -t identity --verbose --from-ref $(git merge-base main HEAD) --to-ref HEAD # branch updates
-
-Those are all available flags of ``static-checks`` command:
+ breeze release-management prepare-provider-packages
-.. image:: ./images/breeze/output-static-checks.svg
- :width: 100%
- :alt: Breeze static checks
+If you run this command without specifying packages, all packages are prepared. You can, however, specify
+the providers that you would like to build. By default ``both`` types of packages are prepared
+(``wheel`` and ``sdist``), but you can change this by providing the optional ``--package-format`` flag.
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+.. code-block:: bash
-.. raw:: html
+ breeze release-management prepare-provider-packages google amazon
-
+You can see all providers available by running this command:
-.. note::
+.. code-block:: bash
- When you run static checks, some of the artifacts (mypy_cache) is stored in docker-compose volume
- so that it can speed up static checks execution significantly. However, sometimes, the cache might
- get broken, in which case you should run ``breeze stop`` to clean up the cache.
+ breeze release-management prepare-provider-packages --help
+.. image:: ./images/breeze/output_release-management_prepare-provider-packages.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_prepare-provider-packages.svg
+ :width: 100%
+ :alt: Breeze prepare-provider-packages
-Building the Documentation
---------------------------
+Verifying provider packages
+...........................
-To build documentation in Breeze, use the ``build-docs`` command:
+Breeze can also be used to verify if provider classes are importable and if they are following the
+right naming conventions. This happens automatically on CI but you can also run it manually if you
+just prepared provider packages and they are present in ``dist`` folder.
.. code-block:: bash
- breeze build-docs
+ breeze release-management verify-provider-packages
-Results of the build can be found in the ``docs/_build`` folder.
+You can also run the verification with an earlier airflow version to check for compatibility.
-The documentation build consists of three steps:
+.. code-block:: bash
-* verifying consistency of indexes
-* building documentation
-* spell checking
+ breeze release-management verify-provider-packages --use-airflow-version 2.1.0
-You can choose only one stage of the two by providing ``--spellcheck-only`` or ``--docs-only`` after
-extra ``--`` flag.
+All the command parameters are here:
-.. code-block:: bash
+.. image:: ./images/breeze/output_release-management_verify-provider-packages.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_verify-provider-packages.svg
+ :width: 100%
+ :alt: Breeze verify-provider-packages
- breeze build-docs --spellcheck-only
-This process can take some time, so in order to make it shorter you can filter by package, using the flag
-``--package-filter ``. The package name has to be one of the providers or ``apache-airflow``. For
-instance, for using it with Amazon, the command would be:
+Preparing airflow packages
+..........................
-.. code-block:: bash
+You can prepare airflow packages using Breeze:
- breeze build-docs --package-filter apache-airflow-providers-amazon
+.. code-block:: bash
-Often errors during documentation generation come from the docstrings of auto-api generated classes.
-During the docs building auto-api generated files are stored in the ``docs/_api`` folder. This helps you
-easily identify the location the problems with documentation originated from.
+ breeze release-management prepare-airflow-package
-Those are all available flags of ``build-docs`` command:
+This prepares airflow .whl package in the dist folder.
-.. image:: ./images/breeze/output-build-docs.svg
- :width: 100%
- :alt: Breeze build documentation
+Again, you can specify the optional ``--package-format`` flag to build selected formats of airflow packages.
+The default is to build ``both`` types of packages, ``sdist`` and ``wheel``.
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+.. code-block:: bash
-.. raw:: html
+ breeze release-management prepare-airflow-package --package-format=wheel
-
+.. image:: ./images/breeze/output_release-management_prepare-airflow-package.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_prepare-airflow-package.svg
+ :width: 100%
+ :alt: Breeze release-management prepare-airflow-package
Generating constraints
-----------------------
+......................
Whenever setup.py gets modified, the CI main job will re-generate constraint files. Those constraint
files are stored in separated orphan branches: ``constraints-main``, ``constraints-2-0``.
@@ -1133,7 +1623,7 @@ Those are constraint files as described in detail in the
``_ contributing documentation.
-You can use ``breeze generate-constraints`` command to manually generate constraints for
+You can use ``breeze release-management generate-constraints`` command to manually generate constraints for
all or selected python version and single constraint mode like this:
.. warning::
@@ -1144,7 +1634,7 @@ all or selected python version and single constraint mode like this:
.. code-block:: bash
- breeze generate-constraints --airflow-constraints-mode constraints
+ breeze release-management generate-constraints --airflow-constraints-mode constraints
Constraints are generated separately for each python version and there are separate constraints modes:
@@ -1163,7 +1653,8 @@ Constraints are generated separately for each python version and there are separ
Those are all available flags of ``generate-constraints`` command:
-.. image:: ./images/breeze/output-generate-constraints.svg
+.. image:: ./images/breeze/output_release-management_generate-constraints.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_generate-constraints.svg
:width: 100%
:alt: Breeze generate-constraints
@@ -1176,157 +1667,140 @@ This bumps the constraint files to latest versions and stores hash of setup.py.
and setup.py hash files are stored in the ``files`` folder and while generating the constraints diff
of changes vs the previous constraint files is printed.
+Releasing Production images
+...........................
-Using local virtualenv environment in Your Host IDE
----------------------------------------------------
-
-You can set up your host IDE (for example, IntelliJ's PyCharm/Idea) to work with Breeze
-and benefit from all the features provided by your IDE, such as local and remote debugging,
-language auto-completion, documentation support, etc.
+The **Production image** can be released by release managers who have permissions to push the image. This
+happens only when a release candidate (RC) or a final version of Airflow is released.
-To use your host IDE with Breeze:
+You release "regular" and "slim" images as separate steps.
-1. Create a local virtual environment:
+Releasing "regular" images:
- You can use any of the following wrappers to create and manage your virtual environments:
- `pyenv `_, `pyenv-virtualenv `_,
- or `virtualenvwrapper `_.
+.. code-block:: bash
-2. Use the right command to activate the virtualenv (``workon`` if you use virtualenvwrapper or
- ``pyenv activate`` if you use pyenv.
+ breeze release-management release-prod-images --airflow-version 2.4.0
-3. Initialize the created local virtualenv:
+Or "slim" images:
.. code-block:: bash
- ./scripts/tools/initialize_virtualenv.py
-
-.. warning::
- Make sure that you use the right Python version in this command - matching the Python version you have
- in your local virtualenv. If you don't, you will get strange conflicts.
+ breeze release-management release-prod-images --airflow-version 2.4.0 --slim-images
-4. Select the virtualenv you created as the project's default virtualenv in your IDE.
+By default, when you release the "final" image, the image is also tagged with the "latest" tags, but this
+step can be skipped if you pass the ``--skip-latest`` flag.
-Note that you can also use the local virtualenv for Airflow development without Breeze.
-This is a lightweight solution that has its own limitations.
+These are all of the available flags for the ``release-prod-images`` command:
-More details on using the local virtualenv are available in the `LOCAL_VIRTUALENV.rst `_.
+.. image:: ./images/breeze/output_release-management_release-prod-images.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output_release-management_release-prod-images.svg
+ :width: 100%
+ :alt: Breeze release management release prod images
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-but it is not available in the ``breeze`` command):
-.. raw:: html
+Details of Breeze usage
+=======================
-
+Database volumes in Breeze
+--------------------------
-Running docker-compose tests
-----------------------------
+Breeze keeps data for all its integrations in named docker volumes. Each backend and integration
+keeps data in its own volume. Those volumes are persisted until the ``breeze stop`` command is run.
+You can also preserve the volumes by adding the ``--preserve-volumes`` flag when you run the command.
+Then, the next time you start Breeze, it will have the data pre-populated.
-You can use Breeze to run docker-compose tests. Those tests are run using Production image
-and they are running test with the Quick-start docker compose we have.
+Those are all available flags of ``stop`` command:
-.. image:: ./images/breeze/output-docker-compose-tests.svg
+.. image:: ./images/breeze/output-stop.svg
+ :target: https://raw.githubusercontent.com/apache/airflow/main/images/breeze/output-stop.svg
:width: 100%
- :alt: Breeze generate-constraints
-
-
-Running Kubernetes tests
-------------------------
+ :alt: Breeze stop
-Breeze helps with running Kubernetes tests in the same environment/way as CI tests are run.
-Breeze helps to setup KinD cluster for testing, setting up virtualenv and downloads the right tools
-automatically to run the tests.
-This is described in detail in `Testing Kubernetes `_.
+Additional tools
+----------------
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command and it is not yet available in the current ``breeze`` command):
+To shrink the Docker image, not all tools are pre-installed in the Docker image. But we have made sure that there
+is an easy process to install additional tools.
-.. raw:: html
+Additional tools are installed in ``/files/bin``. This path is added to ``$PATH``, so your shell will
+automatically autocomplete files that are in that directory. You can also keep the binaries for your tools
+in this directory if you need to.
-
+**Installation scripts**
-Stopping the interactive environment
-------------------------------------
+For development convenience, we have also provided installation scripts for commonly used tools. They are
+installed to ``/files/opt/``, so they are preserved after restarting the Breeze environment. Each script
+is also available in ``$PATH``, so just type ``install_`` to get a list of tools.
-After starting up, the environment runs in the background and takes precious memory.
-You can always stop it via:
+Currently available scripts:
-.. code-block:: bash
+* ``install_aws.sh`` - installs `the AWS CLI `__
+* ``install_az.sh`` - installs `the Azure CLI `__
+* ``install_gcloud.sh`` - installs `the Google Cloud SDK `__ including
+ ``gcloud``, ``gsutil``.
+* ``install_imgcat.sh`` - installs `imgcat - Inline Images Protocol `__
+ for iTerm2 (Mac OS only)
+* ``install_java.sh`` - installs `the OpenJDK 8u41 `__
+* ``install_kubectl.sh`` - installs `the Kubernetes command-line tool, kubectl `__
+* ``install_snowsql.sh`` - installs `SnowSQL `__
+* ``install_terraform.sh`` - installs `Terraform `__
- breeze stop
+Launching Breeze integrations
+-----------------------------
-Those are all available flags of ``stop`` command:
+When Breeze starts, it can start additional integrations. Those are additional docker containers
+that are started in the same docker-compose command. Those are required by some of the tests
+as described in ``_.
-.. image:: ./images/breeze/output-stop.svg
- :width: 100%
- :alt: Breeze stop
+By default Breeze starts only airflow container without any integration enabled. If you selected
+``postgres`` or ``mysql`` backend, the container for the selected backend is also started (but only the one
+that is selected). You can start the additional integrations by passing ``--integration`` flag
+with appropriate integration name when starting Breeze. You can specify several ``--integration`` flags
+to start more than one integration at a time.
+Finally you can specify ``--integration all`` to start all integrations.
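
Passing several ``--integration`` flags can be sketched as below. The helper ``build_integration_flags`` is hypothetical and not part of Breeze, and ``first`` and ``second`` are placeholders for real integration names. The snippet only prints the resulting command line rather than running it.

```shell
# Build a repeated "--integration NAME" flag list from the given names.
# build_integration_flags is a hypothetical helper, not part of Breeze.
build_integration_flags() {
    flags=""
    for name in "$@"; do
        # Add a separating space only when some flags were already collected.
        flags="${flags:+$flags }--integration $name"
    done
    printf '%s\n' "$flags"
}

# "first" and "second" are placeholders for real integration names.
echo "breeze $(build_integration_flags first second)"
# prints: breeze --integration first --integration second
```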
-Here is the part of Breeze video which is relevant (note that it refers to the old ``./breeze-legacy``
-command but it is very similar to current ``breeze`` command):
+Once an integration is started, it will continue to run until the environment is stopped with the
+``breeze stop`` command or restarted via the ``breeze restart`` command.
-.. raw:: html
+Note that running integrations uses significant resources - CPU and memory.
-
-Resource check
-==============
+Using local virtualenv environment in Your Host IDE
+---------------------------------------------------
-Breeze requires certain resources to be available - disk, memory, CPU. When you enter Breeze's shell,
-the resources are checked and information if there is enough resources is displayed. However you can
-manually run resource check any time by ``breeze resource-check`` command.
+You can set up your host IDE (for example, IntelliJ's PyCharm/Idea) to work with Breeze
+and benefit from all the features provided by your IDE, such as local and remote debugging,
+language auto-completion, documentation support, etc.
-Those are all available flags of ``resource-check`` command:
+To use your host IDE with Breeze:
-.. image:: ./images/breeze/output-resource-check.svg
- :width: 100%
- :alt: Breeze resource-check
+1. Create a local virtual environment:
+ You can use any of the following wrappers to create and manage your virtual environments:
+ `pyenv `_, `pyenv-virtualenv `_,
+ or `virtualenvwrapper `_.
-Freeing the space
-=================
+2. Use the right command to activate the virtualenv (``workon`` if you use virtualenvwrapper or
+ ``pyenv activate`` if you use pyenv).
-When our CI runs a job, it needs all memory and disk it can have. We have a Breeze command that frees
-the memory and disk space used. You can also use it clear space locally but it performs a few operations
-that might be a bit invasive - such are removing swap file and complete pruning of docker disk space used.
+3. Initialize the created local virtualenv:
-Those are all available flags of ``free-space`` command:
+.. code-block:: bash
-.. image:: ./images/breeze/output-free-space.svg
- :width: 100%
- :alt: Breeze free-space
+ ./scripts/tools/initialize_virtualenv.py
+.. warning::
+ Make sure that you use the right Python version in this command - matching the Python version you have
+ in your local virtualenv. If you don't, you will get strange conflicts.
-Tracking backtracking issues for CI builds
-==========================================
+4. Select the virtualenv you created as the project's default virtualenv in your IDE.
-When our CI runs a job, we automatically upgrade our dependencies in the ``main`` build. However, this might
-lead to conflicts and ``pip`` backtracking for a long time (possibly forever) for dependency resolution.
-Unfortunately those issues are difficult to diagnose so we had to invent our own tool to help us with
-diagnosing them. This tool is ``find-newer-dependencies`` and it works in the way that it helps to guess
-which new dependency might have caused the backtracking. The whole process is described in
-`tracking backtracking issues `_.
+Note that you can also use the local virtualenv for Airflow development without Breeze.
+This is a lightweight solution that has its own limitations.
-Those are all available flags of ``find-newer-dependencies`` command:
+More details on using the local virtualenv are available in the `LOCAL_VIRTUALENV.rst `_.
-.. image:: ./images/breeze/output-find-newer-dependencies.svg
- :width: 100%
- :alt: Breeze find-newer-dependencies
Internal details of Breeze
==========================
@@ -1341,10 +1815,12 @@ When you are in the CI container, the following directories are used:
/opt/airflow - Contains sources of Airflow mounted from the host (AIRFLOW_SOURCES).
/root/airflow - Contains all the "dynamic" Airflow files (AIRFLOW_HOME), such as:
airflow.db - sqlite database in case sqlite is used;
- dags - folder with non-test dags (test dags are in /opt/airflow/tests/dags);
logs - logs from Airflow executions;
unittest.cfg - unit test configuration generated when entering the environment;
webserver_config.py - webserver configuration generated when running Airflow in the container.
+ /files - files mounted from "files" folder in your sources. You can edit them in the host as well
+ dags - this is the folder where Airflow DAGs are read from
+ airflow-breeze-config - this is where you can keep your own customization configuration of breeze
Note that when running in your local environment, the ``/root/airflow/logs`` folder is actually mounted
from your ``logs`` directory in the Airflow sources, so all logs created in the container are automatically
@@ -1358,43 +1834,17 @@ When you are in the production container, the following directories are used:
/opt/airflow - Contains sources of Airflow mounted from the host (AIRFLOW_SOURCES).
/root/airflow - Contains all the "dynamic" Airflow files (AIRFLOW_HOME), such as:
airflow.db - sqlite database in case sqlite is used;
- dags - folder with non-test dags (test dags are in /opt/airflow/tests/dags);
logs - logs from Airflow executions;
unittest.cfg - unit test configuration generated when entering the environment;
webserver_config.py - webserver configuration generated when running Airflow in the container.
+ /files - files mounted from "files" folder in your sources. You can edit them in the host as well
+ dags - this is the folder where Airflow DAGs are read from
Note that when running in your local environment, the ``/root/airflow/logs`` folder is actually mounted
from your ``logs`` directory in the Airflow sources, so all logs created in the container are automatically
visible in the host as well. Every time you enter the container, the ``logs`` directory is
cleaned so that logs do not accumulate.
-Running Arbitrary commands in the Breeze environment
-----------------------------------------------------
-
-To run other commands/executables inside the Breeze Docker-based environment, use the
-``breeze shell`` command.
-
-.. code-block:: bash
-
- breeze shell "ls -la"
-
-Those are all available flags of ``shell`` command:
-
-.. image:: ./images/breeze/output-shell.svg
- :width: 100%
- :alt: Breeze shell
-
-Running "Docker Compose" commands
----------------------------------
-
-To run Docker Compose commands (such as ``help``, ``pull``, etc), use the
-``docker-compose`` command. To add extra arguments, specify them
-after ``--`` as extra arguments.
-
-.. code-block:: bash
-
- ./breeze-legacy docker-compose pull -- --ignore-pull-failures
-
Setting default answers for user interaction
--------------------------------------------
@@ -1410,25 +1860,6 @@ For automation scripts, you can export the ``ANSWER`` variable (and set it to
export ANSWER="yes"
-Fixing File/Directory Ownership
--------------------------------
-
-On Linux, there is a problem with propagating ownership of created files (a known Docker problem). The
-files and directories created in the container are not owned by the host user (but by the root user in our
-case). This may prevent you from switching branches, for example, if files owned by the root user are
-created within your sources. In case you are on a Linux host and have some files in your sources created
-by the root user, you can fix the ownership of those files by running :
-
-.. code-block::
-
- breeze fix-ownership
-
-Those are all available flags of ``fix-ownership`` command:
-
-.. image:: ./images/breeze/output-fix-ownership.svg
- :width: 100%
- :alt: Breeze fix-ownership
-
Mounting Local Sources to Breeze
--------------------------------
@@ -1448,6 +1879,11 @@ By default ``/files/dags`` folder is mounted from your local ``
the directory used by airflow scheduler and webserver to scan dags for. You can use it to test your dags
from local sources in Airflow. If you wish to add local DAGs that can be run by Breeze.
+The ``/files/airflow-breeze-config`` folder contains configuration files that might be used to
+customize your breeze instance. Those files are kept when you check out code from different
+branches and when you stop/start breeze, so you can keep your configuration there and use it continuously while
+you switch to different source code versions.
+
Port Forwarding
---------------
@@ -1508,28 +1944,18 @@ If you set these variables, next time when you enter the environment the new por
Managing Dependencies
---------------------
-If you need to change apt dependencies in the ``Dockerfile.ci``, add Python packages in ``setup.py`` or
-add JavaScript dependencies in ``package.json``, you can either add dependencies temporarily for a single
-Breeze session or permanently in ``setup.py``, ``Dockerfile.ci``, or ``package.json`` files.
-
-Installing Dependencies for a Single Breeze Session
-...................................................
-
-You can install dependencies inside the container using ``sudo apt install``, ``pip install`` or
-``yarn install`` (in ``airflow/www`` folder) respectively. This is useful if you want to test something
-quickly while you are in the container. However, these changes are not retained: they disappear once you
-exit the container (except for the node.js dependencies if your sources are mounted to the container).
-Therefore, if you want to retain a new dependency, follow the second option described below.
+If you need to change apt dependencies, do it in the ``Dockerfile.ci``. Add Python packages in ``setup.py``
+for airflow and in ``provider.yaml`` for provider packages. If you add any "node" dependencies in ``airflow/www``,
+you need to compile them in the host with the ``breeze compile-www-assets`` command.
Adding Dependencies Permanently
...............................
-You can add dependencies to the ``Dockerfile.ci``, ``setup.py`` or ``package.json`` and rebuild the image.
-This should happen automatically if you modify any of these files.
+You can add dependencies to the ``Dockerfile.ci`` or ``setup.py``.
After you exit the container and re-run ``breeze``, Breeze detects changes in dependencies,
asks you to confirm rebuilding the image and proceeds with rebuilding if you confirm (or skip it
if you do not confirm). After rebuilding is done, Breeze drops you to shell. You may also use the
-``build-image`` command to only rebuild CI image and not to go into shell.
+``build`` command to only rebuild CI image and not to go into shell.
Incremental apt Dependencies in the Dockerfile.ci during development
....................................................................
diff --git a/CI.rst b/CI.rst
index f24639271e977..5c937e579431f 100644
--- a/CI.rst
+++ b/CI.rst
@@ -21,7 +21,7 @@ CI Environment
==============
Continuous Integration is important component of making Apache Airflow robust and stable. We are running
-a lot of tests for every pull request, for main and v2-*-test branches and regularly as CRON jobs.
+a lot of tests for every pull request, for main and v2-*-test branches and regularly as scheduled jobs.
Our execution environment for CI is `GitHub Actions `_. GitHub Actions
(GA) are very well integrated with GitHub code and Workflow and it has evolved fast in 2019/202 to become
@@ -33,7 +33,7 @@ environments we use. Most of our CI jobs are written as bash scripts which are e
the CI jobs. And we have a number of variables determine build behaviour.
You can also take a look at the `CI Sequence Diagrams `_ for more graphical overview
-of how Airlfow's CI works.
+of how Airflow CI works.
GitHub Actions runs
-------------------
@@ -84,288 +84,88 @@ We use `GitHub Container Registry `_ but in essence it is a script that allows
-you to re-create CI environment in your local development instance and interact with it. In its basic
-form, when you do development you can run all the same tests that will be run in CI - but locally,
-before you submit them as PR. Another use case where Breeze is useful is when tests fail on CI. You can
-take the full ``COMMIT_SHA`` of the failed build pass it as ``--github-image-id`` parameter of Breeze and it will
-download the very same version of image that was used in CI and run it locally. This way, you can very
-easily reproduce any failed test that happens in CI - even if you do not check out the sources
-connected with the run.
-You can read more about it in `BREEZE.rst `_ and `TESTING.rst `_
+Naming conventions for stored images
+====================================
-Difference between local runs and GitHub Action workflows
----------------------------------------------------------
+The images produced during the ``Build Images`` workflow of CI jobs are stored in the
+`GitHub Container Registry `_
-Depending whether the scripts are run locally (most often via `Breeze `_) or whether they
-are run in ``Build Images`` or ``Tests`` workflows they can take different values.
+The images are stored with both the "latest" tag (for the last main push image that passes all the tests) and
+the COMMIT_SHA id (for images that were used in a particular build).
-You can use those variables when you try to reproduce the build locally.
+The image names follow the patterns (except for the Python image, all the images are stored in
+https://ghcr.io/ in the ``apache`` organization).
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Variable | Local | Build Images | Tests | Comment |
-| | development | CI workflow | Workflow | |
-+=========================================+=============+==============+============+=================================================+
-| Basic variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``PYTHON_MAJOR_MINOR_VERSION`` | | | | Major/Minor version of Python used. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``DB_RESET`` | false | true | true | Determines whether database should be reset |
-| | | | | at the container entry. By default locally |
-| | | | | the database is not reset, which allows to |
-| | | | | keep the database content between runs in |
-| | | | | case of Postgres or MySQL. However, |
-| | | | | it requires to perform manual init/reset |
-| | | | | if you stop the environment. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Mount variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``MOUNT_SELECTED_LOCAL_SOURCES`` | true | false | false | Determines whether local sources are |
-| | | | | mounted to inside the container. Useful for |
-| | | | | local development, as changes you make |
-| | | | | locally can be immediately tested in |
-| | | | | the container. We mount only selected, |
-| | | | | important folders. We do not mount the whole |
-| | | | | project folder in order to avoid accidental |
-| | | | | use of artifacts (such as ``egg-info`` |
-| | | | | directories) generated locally on the |
-| | | | | host during development. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``MOUNT_ALL_LOCAL_SOURCES`` | false | false | false | Determines whether all local sources are |
-| | | | | mounted to inside the container. Useful for |
-| | | | | local development when you need to access .git |
-| | | | | folders and other folders excluded when |
-| | | | | ``MOUNT_SELECTED_LOCAL_SOURCES`` is true. |
-| | | | | You might need to manually delete egg-info |
-| | | | | folder when you enter breeze and the folder was |
-| | | | | generated using different Python versions. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Force variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``FORCE_BUILD_IMAGES`` | false | false | false | Forces building images. This is generally not |
-| | | | | very useful in CI as in CI environment image |
-| | | | | is built or pulled only once, so there is no |
-| | | | | need to set the variable to true. For local |
-| | | | | builds it forces rebuild, regardless if it |
-| | | | | is determined to be needed. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``ANSWER`` | | yes | yes | This variable determines if answer to questions |
-| | | | | during the build process should be |
-| | | | | automatically given. For local development, |
-| | | | | the user is occasionally asked to provide |
-| | | | | answers to questions such as - whether |
-| | | | | the image should be rebuilt. By default |
-| | | | | the user has to answer but in the CI |
-| | | | | environment, we force "yes" answer. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``SKIP_CHECK_REMOTE_IMAGE`` | false | true | true | Determines whether we check if remote image |
-| | | | | is "fresher" than the current image. |
-| | | | | When doing local breeze runs we try to |
-| | | | | determine if it will be faster to rebuild |
-| | | | | the image or whether the image should be |
-| | | | | pulled first from the cache because it has |
-| | | | | been rebuilt. This is slightly experimental |
-| | | | | feature and will be improved in the future |
-| | | | | as the current mechanism does not always |
-| | | | | work properly. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Host variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``HOST_USER_ID`` | | | | User id of the host user. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``HOST_GROUP_ID`` | | | | Group id of the host user. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``HOST_OS`` | | Linux | Linux | OS of the Host (Darwin/Linux). |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Git variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``COMMIT_SHA`` | | GITHUB_SHA | GITHUB_SHA | SHA of the commit of the build is run |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Initialization |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``SKIP_ENVIRONMENT_INITIALIZATION`` | false\* | false\* | false\* | Skip initialization of test environment |
-| | | | | |
-| | | | | \* set to true in pre-commits |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``SKIP_SSH_SETUP`` | false\* | false\* | false\* | Skip setting up SSH server for tests. |
-| | | | | |
-| | | | | \* set to true in GitHub CodeSpaces |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Verbosity variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``PRINT_INFO_FROM_SCRIPTS`` | true\* | true\* | true\* | Allows to print output to terminal from running |
-| | | | | scripts. It prints some extra outputs if true |
-| | | | | including what the commands do, results of some |
-| | | | | operations, summary of variable values, exit |
-| | | | | status from the scripts, outputs of failing |
-| | | | | commands. If verbose is on it also prints the |
-| | | | | commands executed by docker, kind, helm, |
-| | | | | kubectl. Disabled in pre-commit checks. |
-| | | | | |
-| | | | | \* set to false in pre-commits |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``VERBOSE`` | false | true | true | Determines whether docker, helm, kind, |
-| | | | | kubectl commands should be printed before |
-| | | | | execution. This is useful to determine |
-| | | | | what exact commands were executed for |
-| | | | | debugging purpose as well as allows |
-| | | | | to replicate those commands easily by |
-| | | | | copy&pasting them from the output. |
-| | | | | requires ``PRINT_INFO_FROM_SCRIPTS`` set to |
-| | | | | true. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``VERBOSE_COMMANDS`` | false | false | false | Determines whether every command |
-| | | | | executed in bash should also be printed |
-| | | | | before execution. This is a low-level |
-| | | | | debugging feature of bash (set -x) and |
-| | | | | it should only be used if you are lost |
-| | | | | at where the script failed. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| Image build variables |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``UPGRADE_TO_NEWER_DEPENDENCIES`` | false | false | false\* | Determines whether the build should |
-| | | | | attempt to upgrade Python base image and all |
-| | | | | PIP dependencies to latest ones matching |
-| | | | | ``setup.py`` limits. This tries to replicate |
-| | | | | the situation of "fresh" user who just installs |
-| | | | | airflow and uses latest version of matching |
-| | | | | dependencies. By default we are using a |
-| | | | | tested set of dependency constraints |
-| | | | | stored in separated "orphan" branches |
-| | | | | of the airflow repository |
-| | | | | ("constraints-main, "constraints-2-0") |
-| | | | | but when this flag is set to anything but false |
-| | | | | (for example random value), they are not used |
-| | | | | used and "eager" upgrade strategy is used |
-| | | | | when installing dependencies. We set it |
-| | | | | to true in case of direct pushes (merges) |
-| | | | | to main and scheduled builds so that |
-| | | | | the constraints are tested. In those builds, |
-| | | | | in case we determine that the tests pass |
-| | | | | we automatically push latest set of |
-| | | | | "tested" constraints to the repository. |
-| | | | | |
-| | | | | Setting the value to random value is best way |
-| | | | | to assure that constraints are upgraded even if |
-| | | | | there is no change to setup.py |
-| | | | | |
-| | | | | This way our constraints are automatically |
-| | | | | tested and updated whenever new versions |
-| | | | | of libraries are released. |
-| | | | | |
-| | | | | \* true in case of direct pushes and |
-| | | | | scheduled builds |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``CHECK_IMAGE_FOR_REBUILD`` | true | true | true\* | Determines whether attempt should be |
-| | | | | made to rebuild the CI image with latest |
-| | | | | sources. It is true by default for |
-| | | | | local builds, however it is set to |
-| | | | | true in case we know that the image |
-| | | | | we pulled or built already contains |
-| | | | | the right sources. In such case we |
-| | | | | should set it to false, especially |
-| | | | | in case our local sources are not the |
-| | | | | ones we intend to use (for example |
-| | | | | when ``--github-image-id`` is used |
-| | | | | in Breeze. |
-| | | | | |
-| | | | | In CI jobs it is set to true |
-| | | | | in case of the ``Build Images`` |
-| | | | | workflow or when |
-| | | | | waiting for images is disabled |
-| | | | | in the "Tests" workflow. |
-| | | | | |
-| | | | | \* if waiting for images the variable is set |
-| | | | | to false automatically. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
-| ``SKIP_BUILDING_PROD_IMAGE`` | false | false | false\* | Determines whether we should skip building |
-| | | | | the PROD image with latest sources. |
-| | | | | It is set to false, but in deploy app for |
-| | | | | kubernetes step it is set to "true", because at |
-| | | | | this stage we know we have good image build or |
-| | | | | pulled. |
-| | | | | |
-| | | | | \* set to true in "Deploy App to Kubernetes" |
-| | | | | to false automatically. |
-+-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+The packages are available under the URL below (CONTAINER_NAME is the url-encoded name of the image). Note that "/" is
+now supported in ``ghcr.io`` as a part of the image name within the ``apache`` organization, but it
+has to be percent-encoded when you access the packages via the UI (/ = %2F):
-Running CI Jobs locally
-=======================
-
-The scripts and configuration files for CI jobs are all in ``scripts/ci`` - so that in the
-``pull_request_target`` target workflow, we can copy those scripts from the ``main`` branch and use them
-regardless of the changes done in the PR. This way we are kept safe from PRs injecting code into the builds.
-
-* ``build_airflow`` - builds airflow packages
-* ``constraints`` - scripts to build and publish latest set of valid constraints
-* ``docs`` - scripts to build documentation
-* ``images`` - scripts to build and push CI and PROD images
-* ``kubernetes`` - scripts to setup kubernetes cluster, deploy airflow and run kubernetes tests with it
-* ``openapi`` - scripts to run openapi generation
-* ``pre_commit`` - scripts to run pre-commit checks
-* ``provider_packages`` - scripts to build and test provider packages
-* ``static_checks`` - scripts to run static checks manually
-* ``testing`` - scripts that run unit and integration tests
-* ``tools`` - scripts that can be used for various clean-up and preparation tasks
-
-Common libraries of functions for all the scripts can be found in ``libraries`` folder. The ``dockerfiles``,
-``mysql.d``, ``openldap``, ``spectral_rules`` folders contains DockerFiles and configuration of integrations
-needed to run tests.
-
-For detailed use of those scripts you can refer to ``.github/workflows/`` - those scripts are used
-by the CI workflows of ours. There are some variables that you can set to change the behaviour of the
-scripts.
-
-The default values are "sane" you can change them to interact with your own repositories or registries.
-Note that you need to set "CI" variable to true in order to get the same results as in CI.
-
-+------------------------------+----------------------+-----------------------------------------------------+
-| Variable | Default | Comment |
-+==============================+======================+=====================================================+
-| CI | ``false`` | If set to "true", we simulate behaviour of |
-| | | all scripts as if they are in CI environment |
-+------------------------------+----------------------+-----------------------------------------------------+
-| CI_TARGET_REPO | ``apache/airflow`` | Target repository for the CI job. Used to |
-| | | compare incoming changes from PR with the target. |
-+------------------------------+----------------------+-----------------------------------------------------+
-| CI_TARGET_BRANCH | ``main`` | Target branch where the PR should land. Used to |
-| | | compare incoming changes from PR with the target. |
-+------------------------------+----------------------+-----------------------------------------------------+
-| CI_BUILD_ID | ``0`` | Unique id of the build that is kept across re runs |
-| | | (for GitHub actions it is ``GITHUB_RUN_ID``) |
-+------------------------------+----------------------+-----------------------------------------------------+
-| CI_JOB_ID | ``0`` | Unique id of the job - used to produce unique |
-| | | artifact names. |
-+------------------------------+----------------------+-----------------------------------------------------+
-| CI_EVENT_TYPE | ``pull_request`` | Type of the event. It can be one of |
-| | | [``pull_request``, ``pull_request_target``, |
-| | | ``schedule``, ``push``] |
-+------------------------------+----------------------+-----------------------------------------------------+
-| CI_REF | ``refs/head/main`` | Branch in the source repository that is used to |
-| | | make the pull request. |
-+------------------------------+----------------------+-----------------------------------------------------+
+``https://github.com/apache/airflow/pkgs/container/``
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+| Image        | Name:tag (both cases latest version and per-build)       | Description                                              |
++==============+==========================================================+==========================================================+
+| Python image | python:-slim-bullseye                                    | Base Python image used by both production and CI image.  |
+| (DockerHub)  |                                                          | Python maintainers release new versions of those images  |
+|              |                                                          | with security fixes every few weeks in DockerHub.        |
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+| Airflow      | airflow//python:-slim-bullseye                           | Version of the Python base image used in Airflow builds. |
+| python base  |                                                          | We keep the "latest" version only to mark last "good"    |
+| image        |                                                          | python base that went through testing and was pushed.    |
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+| PROD Build   | airflow//prod-build/python:latest                        | Production Build image - this is the "build" stage of    |
+| image        |                                                          | production image. It contains build-essentials and all   |
+|              |                                                          | necessary apt packages to build/install PIP packages.    |
+|              |                                                          | We keep the "latest" version only to speed up builds.    |
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+| Manifest     | airflow//ci-manifest/python:latest                       | CI manifest image - this is the image used to optimize   |
+| CI image     |                                                          | pulls and builds for Breeze development environment.     |
+|              |                                                          | It stores a hash indicating whether the image will be    |
+|              |                                                          | faster to build or pull.                                 |
+|              |                                                          | We keep the "latest" version only to help breeze to      |
+|              |                                                          | check if new image should be pulled.                     |
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+| CI image     | airflow//ci/python:latest                                | CI image - this is the image used for most of the tests. |
+|              | or                                                       | Contains all provider dependencies and tools useful      |
+|              | airflow//ci/python:                                      | for testing. This image is used in Breeze.               |
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+| PROD image   | airflow//prod/python:latest                              | Production image. This is the actual production image    |
+|              | or                                                       | optimized for size.                                      |
+|              | airflow//prod/python:                                    | It contains only compiled libraries and a minimal set of |
+|              |                                                          | dependencies to run Airflow.                             |
++--------------+----------------------------------------------------------+----------------------------------------------------------+
+
+* might be either "main" or "v2-*-test"
+* - Python version (Major + Minor). Should be one of ["3.7", "3.8", "3.9"].
+* - full-length SHA of commit either from the tip of the branch (for pushes/schedule) or
+ commit from the tip of the branch used for the PR.
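
The percent-encoding rule for package pages can be illustrated with a minimal Python sketch. The helper name ``container_page_url`` and the example image name are assumptions made for illustration only; the image name is modelled on the ``docker tag`` example that appears later in this document, and the base URL is the one quoted above.

```python
from urllib.parse import quote

# Base of the package pages, as quoted in the text above.
PACKAGES_BASE_URL = "https://github.com/apache/airflow/pkgs/container/"


def container_page_url(image_name: str) -> str:
    """Return the UI URL for an image name, percent-encoding "/" as %2F."""
    # safe="" makes quote() encode "/" as well (by default "/" is left as-is).
    return PACKAGES_BASE_URL + quote(image_name, safe="")


# Hypothetical example name, used only to show the encoding.
example_url = container_page_url("airflow/main/ci/python3.10")
print(example_url)
```

Only the image name is encoded; the slashes that belong to the base URL itself stay untouched.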
GitHub Registry Variables
=========================
-Our CI uses GitHub Registry to pull and push images to/from by default. You can use your own repo by changing
-``GITHUB_REPOSITORY`` and providing your own GitHub Username and Token.
+Our CI uses GitHub Registry to pull and push images to/from by default. Those variables are set automatically
+by GitHub Actions when you run Airflow workflows in your fork, so they should automatically use your
+own repository as GitHub Registry to build and keep the images as build image cache.
+
+The variables are automatically set in GitHub actions
+--------------------------------+---------------------------+----------------------------------------------+
| Variable | Default | Comment |
+================================+===========================+==============================================+
| GITHUB_REPOSITORY | ``apache/airflow`` | Prefix of the image. It indicates which |
-| | | registry from GitHub to use |
+| | | registry from GitHub to use for image cache |
+| | | and to determine the name of the image. |
++--------------------------------+---------------------------+----------------------------------------------+
+| CONSTRAINTS_GITHUB_REPOSITORY | ``apache/airflow`` | Repository where constraints are stored |
+--------------------------------+---------------------------+----------------------------------------------+
| GITHUB_USERNAME | | Username to use to login to GitHub |
| | | |
@@ -373,37 +173,37 @@ Our CI uses GitHub Registry to pull and push images to/from by default. You can
| GITHUB_TOKEN | | Token to use to login to GitHub. |
| | | Only used when pushing images on CI. |
+--------------------------------+---------------------------+----------------------------------------------+
-| GITHUB_REGISTRY_PULL_IMAGE_TAG | ``latest`` | Pull this image tag. This is "latest" by |
-| | | default, can also be full-length commit SHA. |
-+--------------------------------+---------------------------+----------------------------------------------+
-| GITHUB_REGISTRY_PUSH_IMAGE_TAG | ``latest`` | Push this image tag. This is "latest" by |
-| | | default, can also be full-length commit SHA. |
-+--------------------------------+---------------------------+----------------------------------------------+
+
+The variables beginning with ``GITHUB_`` cannot be overridden in GitHub Actions by the workflow.
+Those variables are set by GitHub Actions automatically and they are immutable. Therefore, if
+you want to override them in your own CI workflow and use ``breeze``, you need to pass the
+values via the corresponding ``breeze`` flags ``--github-repository``, ``--github-username``,
+``--github-token`` rather than by setting them as environment variables in your workflow.
+Unless you want to keep your own copy of constraints in orphaned ``constraints-*``
+branches, ``CONSTRAINTS_GITHUB_REPOSITORY`` should remain ``apache/airflow``, regardless of the
+repository in which the CI job is run.
+
+One of the variables you might want to override in your own GitHub Actions workflow when using ``breeze`` is
+``--github-repository`` - you might want to force it to ``apache/airflow``, because then the cache from
+``apache/airflow`` repository will be used and your builds will be much faster.
+
+Example command to build your CI image efficiently in your own CI workflow:
+
+.. code-block:: bash
+
+ # GITHUB_REPOSITORY is set automatically in GitHub Actions, so we need to override it with a flag
+ breeze ci-image build --github-repository apache/airflow --python 3.10
+ docker tag ghcr.io/apache/airflow/main/ci/python3.10 your-image-name:tag
+
Authentication in GitHub Registry
=================================
We are using GitHub Container Registry as cache for our images. Authentication uses GITHUB_TOKEN mechanism.
Authentication is needed for pushing the images (WRITE) only in "push", "pull_request_target" workflows.
+When you are running the CI jobs in GitHub Actions, GITHUB_TOKEN is set automatically by the actions.
-CI Architecture
-===============
-
-The following components are part of the CI infrastructure
-
-* **Apache Airflow Code Repository** - our code repository at https://github.com/apache/airflow
-* **Apache Airflow Forks** - forks of the Apache Airflow Code Repository from which contributors make
- Pull Requests
-* **GitHub Actions** - (GA) UI + execution engine for our jobs
-* **GA CRON trigger** - GitHub Actions CRON triggering our jobs
-* **GA Workers** - virtual machines running our jobs at GitHub Actions (max 20 in parallel)
-* **GitHub Image Registry** - image registry used as build cache for CI jobs.
- It is at https://ghcr.io/apache/airflow
-* **DockerHub Image Registry** - image registry used to pull base Python images and (manually) publish
- the released Production Airflow images. It is at https://dockerhub.com/apache/airflow
-* **Official Images** (future) - these are official images that are prominently visible in DockerHub.
- We aim our images to become official images so that you will be able to pull them
- with ``docker pull apache-airflow``
CI run types
============
@@ -411,6 +211,12 @@ CI run types
The following CI Job run types are currently run for Apache Airflow (run by ci.yaml workflow)
and each of the run types has different purpose and context.
+Besides the regular "PR" runs, we also have "Canary" runs that detect most of the
+problems that might impact regular PRs early, without necessarily failing all PRs when those
+problems happen. This provides a much more stable environment for contributors
+submitting PRs, while giving maintainers a chance to react early to problems that
+need attention when the "canary" builds fail.
+
Pull request run
----------------
@@ -426,33 +232,37 @@ CI, Production Images as well as base Python images that are also cached in the
Also for those builds we only execute Python tests if important files changed (so, for example, if it is
a "no-code" change, no tests will be executed).
-The workflow involved in Pull Requests review and approval is a bit more complex than simple workflows
-in most of other projects because we've implemented some optimizations related to efficient use
-of queue slots we share with other Apache Software Foundation projects. More details about it
-can be found in `PULL_REQUEST_WORKFLOW.rst `_.
+Regular PR builds run in a "stable" environment:
+
+* fixed set of constraints (constraints that passed the tests) - except for PRs that change dependencies
+* limited matrix and set of tests (determined by selective checks based on what changed in the PR)
+* no ARM images are built in regular PRs
+* lower probability of flaky tests for non-committer PRs (public runners and less parallelism)
-Direct Push/Merge Run
----------------------
+Canary run
+----------
-Those runs are results of direct pushes done by the committers or as result of merge of a Pull Request
+Those runs are results of direct pushes done by the committers - basically the merging of a Pull Request
by the committers. Those runs execute in the context of the Apache Airflow Code Repository and also have
write permission for GitHub resources (container registry, code repository).
+
The main purpose for the run is to check if the code after merge still holds all the assertions - like
-whether it still builds, all tests are green.
+whether it still builds and all tests are green. This is a "Canary" build that helps us detect early
+problems with dependencies, image building, and the full matrix of tests, in case they slipped through selective checks.
This is needed because some of the conflicting changes from multiple PRs might cause build and test failures
after merge even if they do not fail in isolation. Also those runs are already reviewed and confirmed by the
committers so they can be used to do some housekeeping:
-- pushing most recent image build in the PR to the GitHub Container Registry (for caching)
+
+- pushing the most recent image built in the PR to the GitHub Container Registry (for caching), including recent
+ Dockerfile changes and setup.py/setup.cfg changes (Early Cache)
+- testing that the image in the ``breeze`` command builds quickly
+- running the full matrix of tests to detect any tests that were mistakenly skipped by ``selective checks``
- upgrading to the latest constraints and pushing those constraints if all tests succeed
- refreshing the latest Python base images in case a new patch-level is released
The housekeeping is important - Python base images are refreshed with varying frequency (once every few months
usually but sometimes several times per week) with the latest security and bug fixes.
-Those patch level images releases can occasionally break Airflow builds (specifically Docker image builds
-based on those images) therefore in PRs we only use latest "good" Python image that we store in the
-GitHub Container Registry and those push requests will refresh the latest images if they changed.
Scheduled runs
--------------
@@ -528,51 +338,50 @@ Tests Workflow
This workflow is a regular workflow that performs all checks of Airflow code.
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Job | Description | PR | Push | CRON |
-| | | | Merge | (1) |
-+===========================+==============================================+=======+=======+======+
-| Build info | Prints detailed information about the build | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Test OpenAPI client gen | Tests if OpenAPIClient continues to generate | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| UI tests | React UI tests for new Airflow UI | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| WWW tests | React tests for current Airflow UI | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Test image building | Tests if PROD image build examples work | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| CI Images | Waits for and verify CI Images (3) | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| (Basic) Static checks | Performs static checks (full or basic) | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Build docs | Builds documentation | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Tests | Run all the Pytest tests for Python code | Yes(2)| Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Tests provider packages | Tests if provider packages work | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Upload coverage | Uploads test coverage from all the tests | - | Yes | - |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| PROD Images | Waits for and verify PROD Images (3) | Yes | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Tests Kubernetes | Run Kubernetes test | Yes(2)| Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Constraints | Upgrade constraints to latest ones (4) | - | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-| Push images | Pushes latest images to GitHub Registry (4) | - | Yes | Yes |
-+---------------------------+----------------------------------------------+-------+-------+------+
-
-
-Comments:
-
- (1) CRON jobs builds images from scratch - to test if everything works properly for clean builds
- (2) The tests are run when the Trigger Tests job determine that important files change (this allows
- for example "no-code" changes to build much faster)
- (3) The jobs wait for CI images to be available.
- (4) PROD and CI images are pushed as "latest" to GitHub Container registry and constraints are upgraded
- only if all tests are successful. The images are rebuilt in this step using constraints pushed
- in the previous step.
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Job | Description | PR | Canary | Scheduled |
++=============================+==========================================================+=========+==========+===========+
+| Build info | Prints detailed information about the build | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Build CI/PROD images | Builds images in-workflow (not in the build images one) | - | Yes | Yes (1) |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Push early cache & images | Pushes early cache/images to GitHub Registry and test | - | Yes | - |
+| | speed of building breeze images from scratch | | | |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Test OpenAPI client gen | Tests if OpenAPIClient continues to generate | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| UI tests | React UI tests for new Airflow UI | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Test image building | Tests if PROD image build examples work | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| CI Images | Waits for and verify CI Images (2) | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| (Basic) Static checks | Performs static checks (full or basic) | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Build docs | Builds documentation | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Tests | Runs all the Pytest tests for Python code | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Tests provider packages | Tests if provider packages work | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Upload coverage | Uploads test coverage from all the tests | - | Yes | - |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| PROD Images | Waits for and verifies PROD Images (2) | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Tests Kubernetes | Runs Kubernetes tests | Yes | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Constraints | Upgrades constraints to latest ones (3) | - | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+| Push cache & images | Pushes cache/images to GitHub Registry (3) | - | Yes | Yes |
++-----------------------------+----------------------------------------------------------+---------+----------+-----------+
+
+``(1)`` Scheduled jobs build images from scratch to test whether everything works properly for clean builds.
+
+``(2)`` The jobs wait for CI images to be available.
+
+``(3)`` PROD and CI cache & images are pushed as "latest" to GitHub Container registry and constraints are
+upgraded only if all tests are successful. The images are rebuilt in this step using constraints pushed
+in the previous step.
CodeQL scan
-----------
@@ -595,65 +404,60 @@ For more information, see:
Website endpoint: http://apache-airflow-docs.s3-website.eu-central-1.amazonaws.com/
-Naming conventions for stored images
-====================================
-The images produced during the ``Build Images`` workflow of CI jobs are stored in the
-`GitHub Container Registry `_
+Debugging CI Jobs in GitHub Actions
+===================================
-The images are stored with both "latest" tag (for last main push image that passes all the tests as well
-with the COMMIT_SHA id for images that were used in particular build.
+The CI jobs are notoriously difficult to test, because you can only really see their results when you run them
+in the CI environment, and the environment in which they run depends on who runs them: they might run either
+on our self-hosted runners (64 GB RAM, 8 CPUs) or on the GitHub public runners (6 GB RAM, 2 CPUs), and
+the results will differ vastly depending on which environment is used. We use parallelism to make
+use of all the available CPU and memory, but sometimes you need to enable debugging and force certain environments.
+An additional difficulty is that the ``Build Images`` workflow is of the ``pull_request_target`` type, which means
+that it will always run using the ``main`` version, no matter what is in your Pull Request.
-The image names follow the patterns (except the Python image, all the images are stored in
-https://ghcr.io/ in ``apache`` organization.
+There are several ways to debug the CI jobs when you are a maintainer.
-The packages are available under (CONTAINER_NAME is url-encoded name of the image). Note that "/" are
-supported now in the ``ghcr.io`` as apart of the image name within ``apache`` organization, but they
-have to be percent-encoded when you access them via UI (/ = %2F)
+* When you want to test the build with all combinations of Python versions, backends etc. on a regular PR,
+ add the ``full tests needed`` label to the PR.
+* When you want to test a maintainer PR using public runners, add the ``public runners`` label to the PR.
+* When you want to see the resources used by the run, add the ``debug ci resources`` label to the PR.
+* When you want to test changes to breeze that include changes to how images are built, you should push
+ your PR to the ``apache`` repository, not to your fork. This will build the images as part of the ``CI`` workflow
+ rather than using the ``Build images`` workflow and use the same breeze version for building the image and testing.
+* When you want to test changes to the ``build-images.yml`` workflow, you should push your branch as the ``main``
+ branch in your local fork. This will run the changed ``build-images.yml`` workflow as it will be in the ``main``
+ branch of your fork.
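The labels above can also be applied from the command line. A minimal sketch, assuming the
GitHub CLI (``gh``) is installed and authenticated; the ``ci_label_cmd`` helper and the PR number
``12345`` are hypothetical, and the helper only prints the command so that you can review it first:

.. code-block:: bash

    # Hypothetical helper: print the gh command that adds a CI debugging label to a PR.
    # $1 is the PR number, $2 is the label name.
    ci_label_cmd() {
        printf 'gh pr edit %s --repo apache/airflow --add-label "%s"\n' "$1" "$2"
    }

    # Prints the command only; pipe the output to "bash" to actually apply the label.
    ci_label_cmd 12345 "full tests needed"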
-``https://github.com/apache/airflow/pkgs/container/``
+Replicating the CI Jobs locally
+===============================
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
-| Image | Name:tag (both cases latest version and per-build) | Description |
-+==============+==========================================================+==========================================================+
-| Python image | python:-slim-bullseye | Base Python image used by both production and CI image. |
-| (DockerHub) | | Python maintainer release new versions of those image |
-| | | with security fixes every few weeks in DockerHub. |
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
-| Airflow | airflow//python:-slim-bullseye | Version of python base image used in Airflow Builds |
-| python base | | We keep the "latest" version only to mark last "good" |
-| image | | python base that went through testing and was pushed. |
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
-| PROD Build | airflow//prod-build/python:latest | Production Build image - this is the "build" stage of |
-| image | | production image. It contains build-essentials and all |
-| | | necessary apt packages to build/install PIP packages. |
-| | | We keep the "latest" version only to speed up builds. |
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
-| Manifest | airflow//ci-manifest/python:latest | CI manifest image - this is the image used to optimize |
-| CI image | | pulls and builds for Breeze development environment |
-| | | They store hash indicating whether the image will be |
-| | | faster to build or pull. |
-| | | We keep the "latest" version only to help breeze to |
-| | | check if new image should be pulled. |
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
-| CI image | airflow//ci/python:latest | CI image - this is the image used for most of the tests. |
-| | or | Contains all provider dependencies and tools useful |
-| | airflow//ci/python: | For testing. This image is used in Breeze. |
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
-| | | faster to build or pull. |
-| PROD image | airflow//prod/python:latest | Production image. This is the actual production image |
-| | or | optimized for size. |
-| | airflow//prod/python: | It contains only compiled libraries and minimal set of |
-| | | dependencies to run Airflow. |
-+--------------+----------------------------------------------------------+----------------------------------------------------------+
+The main goal of our CI philosophy is that no matter how complex the test and integration
+infrastructure is, as a developer you should be able to reproduce and re-run any of the failed checks
+locally. One part of it is the pre-commit checks, which allow you to run the same static checks in CI
+and locally; another part is the CI environment, which is replicated locally with Breeze.
-* might be either "main" or "v2-*-test"
-* - Python version (Major + Minor).Should be one of ["3.7", "3.8", "3.9"].
-* - full-length SHA of commit either from the tip of the branch (for pushes/schedule) or
- commit from the tip of the branch used for the PR.
+You can read more about Breeze in `BREEZE.rst `_ but in essence it is a script that allows
+you to re-create CI environment in your local development instance and interact with it. In its basic
+form, when you do development you can run all the same tests that will be run in CI - but locally,
+before you submit them as PR. Another use case where Breeze is useful is when tests fail on CI. You can
+take the full ``COMMIT_SHA`` of the failed build, pass it as the ``--image-tag`` parameter of Breeze, and it will
+download the very same version of the image that was used in CI and run it locally. This way, you can very
+easily reproduce any failed test that happens in CI - even if you do not check out the sources
+connected with the run.
+
+All our CI jobs are executed via ``breeze`` commands. You can replicate exactly what our CI is doing
+by running the sequence of corresponding ``breeze`` commands. Make sure, however, that you look at both:
-Reproducing CI Runs locally
-===========================
+* flags passed to ``breeze`` commands
+* environment variables used when ``breeze`` command is run - this is useful when we want
+ to set a common flag for all ``breeze`` commands in the same job or even the whole workflow. For
+ example ``VERBOSE`` variable is set to ``true`` for all our workflows so that more detailed information
+ about internal commands executed in CI is printed.
+
+In the output of the CI jobs, you will find both - the flags passed and environment variables set.
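A minimal sketch of combining the two, assuming the job output showed ``VERBOSE`` set to ``true`` together
with the ``--python`` and ``--backend`` flags. The concrete values are illustrative and should be copied
from the job you are reproducing:

.. code-block:: bash

    # Environment variable copied from the job output (VERBOSE is set to true in the workflows).
    export VERBOSE="true"

    # Flags copied from the job output (illustrative values).
    breeze_flags="--python 3.8 --backend mysql"

    # Print the full command for review; remove "echo" to actually run it.
    echo "VERBOSE=${VERBOSE} breeze ${breeze_flags}"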
+
+You can read more about it in `BREEZE.rst `_ and `TESTING.rst `_
Since we store images from every CI run, you should be able to easily reproduce any of the CI test problems
locally. You can do it by pulling and using the right image and running it with the right docker command,
@@ -668,12 +472,11 @@ For example knowing that the CI job was for commit ``cd27124534b46c9688a1d89e75f
But you usually need to pass more variables and complex setup if you want to connect to a database or
enable some integrations. Therefore it is easiest to use `Breeze `_ for that. For example if
-you need to reproduce a MySQL environment with kerberos integration enabled for commit
-cd27124534b46c9688a1d89e75fcd137ab5137e3, in python 3.8 environment you can run:
+you need to reproduce a MySQL environment with Python 3.8, you can run:
.. code-block:: bash
- ./breeze-legacy --github-image-id cd27124534b46c9688a1d89e75fcd137ab5137e3 --python 3.8
+ breeze --image-tag cd27124534b46c9688a1d89e75fcd137ab5137e3 --python 3.8 --backend mysql
You will be dropped into a shell with the exact version that was used during the CI run and you will
be able to run pytest tests manually, easily reproducing the environment that was used in CI. Note that in
@@ -681,26 +484,117 @@ this case, you do not need to checkout the sources that were used for that run -
the image - but remember that any changes you make in those sources are lost when you leave the image as
the sources are not mapped from your host machine.
+Depending on whether the scripts are run locally via `Breeze `_ or whether they
+are run in the ``Build Images`` or ``Tests`` workflows, the variables below can take different values.
-Adding new Python versions to CI
---------------------------------
+You can use those variables when you try to reproduce the build locally (alternatively, you can pass
+them via command line flags to the ``breeze`` command).
-In the ``main`` branch of development line we currently support Python 3.7, 3.8, 3.9.
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| Variable | Local | Build Images | CI | Comment |
+| | development | workflow | Workflow | |
++=========================================+=============+==============+============+=================================================+
+| Basic variables |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``PYTHON_MAJOR_MINOR_VERSION`` | | | | Major/Minor version of Python used. |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``DB_RESET`` | false | true | true | Determines whether database should be reset |
+| | | | | at the container entry. By default locally |
+| | | | | the database is not reset, which allows to |
+| | | | | keep the database content between runs in |
+| | | | | case of Postgres or MySQL. However, |
+| | | | | it requires to perform manual init/reset |
+| | | | | if you stop the environment. |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| Forcing answer |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``ANSWER`` | | yes | yes | This variable determines if answer to questions |
+| | | | | during the build process should be |
+| | | | | automatically given. For local development, |
+| | | | | the user is occasionally asked to provide |
+| | | | | answers to questions such as - whether |
+| | | | | the image should be rebuilt. By default |
+| | | | | the user has to answer but in the CI |
+| | | | | environment, we force "yes" answer. |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| Host variables |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``HOST_USER_ID`` | | | | User id of the host user. |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``HOST_GROUP_ID`` | | | | Group id of the host user. |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``HOST_OS`` | | linux | linux | OS of the Host (darwin/linux/windows). |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| Git variables |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``COMMIT_SHA`` | | GITHUB_SHA | GITHUB_SHA | SHA of the commit for which the build is run |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| In container environment initialization |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``SKIP_ENVIRONMENT_INITIALIZATION`` | false\* | false\* | false\* | Skip initialization of test environment |
+| | | | | |
+| | | | | \* set to true in pre-commits |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``SKIP_SSH_SETUP`` | false\* | false\* | false\* | Skip setting up SSH server for tests. |
+| | | | | |
+| | | | | \* set to true in GitHub CodeSpaces |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``VERBOSE_COMMANDS`` | false | false | false | Determines whether every command |
+| | | | | executed in docker should also be printed |
+| | | | | before execution. This is a low-level |
+| | | | | debugging feature of bash (set -x) enabled in |
+| | | | | entrypoint and it should only be used if you |
+| | | | | need to debug the bash scripts in container. |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| Image build variables |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
+| ``UPGRADE_TO_NEWER_DEPENDENCIES`` | false | false | false\* | Determines whether the build should |
+| | | | | attempt to upgrade Python base image and all |
+| | | | | PIP dependencies to latest ones matching |
+| | | | | ``setup.py`` limits. This tries to replicate |
+| | | | | the situation of "fresh" user who just installs |
+| | | | | airflow and uses latest version of matching |
+| | | | | dependencies. By default we are using a |
+| | | | | tested set of dependency constraints |
+| | | | | stored in separated "orphan" branches |
+| | | | | of the airflow repository |
+| | | | | ("constraints-main", "constraints-2-0") |
+| | | | | but when this flag is set to anything but false |
+| | | | | (for example random value), they are not used |
+| | | | | and "eager" upgrade strategy is used |
+| | | | | when installing dependencies. We set it |
+| | | | | to true in case of direct pushes (merges) |
+| | | | | to main and scheduled builds so that |
+| | | | | the constraints are tested. In those builds, |
+| | | | | in case we determine that the tests pass |
+| | | | | we automatically push latest set of |
+| | | | | "tested" constraints to the repository. |
+| | | | | |
+| | | | | Setting the value to random value is best way |
+| | | | | to assure that constraints are upgraded even if |
+| | | | | there is no change to setup.py |
+| | | | | |
+| | | | | This way our constraints are automatically |
+| | | | | tested and updated whenever new versions |
+| | | | | of libraries are released. |
+| | | | | |
+| | | | | \* true in case of direct pushes and |
+| | | | | scheduled builds |
++-----------------------------------------+-------------+--------------+------------+-------------------------------------------------+
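A minimal sketch of setting some of the variables from the table before reproducing a CI build locally.
The values mirror the "CI Workflow" column; which variables you actually need depends on the job you are
reproducing:

.. code-block:: bash

    # Values taken from the "CI Workflow" column of the table above.
    export DB_RESET="true"
    export ANSWER="yes"
    export VERBOSE_COMMANDS="false"

    # Show the variables that the subsequent breeze command will inherit.
    env | grep -E '^(DB_RESET|ANSWER|VERBOSE_COMMANDS)=' | sort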
+
+Adding new Python versions to CI
+================================
In order to add a new version the following operations should be done (example uses Python 3.10)
* copy the latest constraints in ``constraints-main`` branch from previous versions and name it
using the new Python version (``constraints-3.10.txt``). Commit and push
-* add the new Python version to `breeze-complete `_ and
- `_initialization.sh `_ - tests will fail if they are not
- in sync.
-
* build the image locally for both prod and CI using Breeze:
.. code-block:: bash
- breeze build-image --python 3.10
+ breeze ci-image build --python 3.10
* Find the 2 new images (prod, ci) created in
`GitHub Container registry `_
diff --git a/CI_DIAGRAMS.md b/CI_DIAGRAMS.md
index 3b228b13c71b7..e19228a02444d 100644
--- a/CI_DIAGRAMS.md
+++ b/CI_DIAGRAMS.md
@@ -38,7 +38,6 @@ sequenceDiagram
Note over Tests: Skip Build<br>(Runs in 'Build Images')<br>CI Images
Note over Tests: Skip Build<br>(Runs in 'Build Images')<br>PROD Images
par
- GitHub Registry ->> Build Images: Pull CI Images<br>[latest]
Note over Build Images: Build CI Images<br>[COMMIT_SHA]<br>Use latest constraints<br>or upgrade if setup changed
and
Note over Tests: OpenAPI client gen
@@ -121,7 +120,6 @@ sequenceDiagram
deactivate Build Images
Note over Tests: Build info<br>Decide on tests<br>Decide on Matrix (selective)
par
- GitHub Registry ->> Tests: Pull CI Images<br>[latest]
Note over Tests: Build CI Images<br>[COMMIT_SHA]<br>Use latest constraints<br>or upgrade if setup changed
and
Note over Tests: OpenAPI client gen
@@ -134,7 +132,6 @@ sequenceDiagram
GitHub Registry ->> Tests: Pull CI Images<br>[COMMIT_SHA]
Note over Tests: Verify CI Image<br>[COMMIT_SHA]
par
- GitHub Registry ->> Tests: Pull PROD Images<br>[latest]
Note over Tests: Build PROD Images<br>[COMMIT_SHA]
and
opt
@@ -181,18 +178,17 @@ sequenceDiagram
deactivate Tests
```
-## Direct Push/Merge flow
+## Merge "Canary" run
```mermaid
sequenceDiagram
- Note over Airflow Repo: pull request
+ Note over Airflow Repo: push/merge
Note over Tests: push<br>[Write Token]
activate Airflow Repo
Airflow Repo -->> Tests: Trigger 'push'
activate Tests
Note over Tests: Build info<br>All tests<br>Full matrix
par
- GitHub Registry ->> Tests: Pull CI Images<br>[latest]
Note over Tests: Build CI Images<br>[COMMIT_SHA]<br>Always upgrade deps
and
Note over Tests: OpenAPI client gen
@@ -200,12 +196,15 @@ sequenceDiagram
Note over Tests: Test UI
and
Note over Tests: Test examples<br>PROD image building
+ and
+ Note over Tests: Build CI Images<br>Use original constraints
+ Tests ->> GitHub Registry: Push CI Image Early cache + latest
+ Note over Tests: Test 'breeze' image build quickly
end
Tests ->> GitHub Registry: Push CI Images<br>[COMMIT_SHA]
GitHub Registry ->> Tests: Pull CI Images<br>[COMMIT_SHA]
Note over Tests: Verify CI Image<br>[COMMIT_SHA]
par
- GitHub Registry ->> Tests: Pull PROD Images<br>[latest]
Note over Tests: Build PROD Images<br>[COMMIT_SHA]
and
opt
@@ -249,11 +248,9 @@ sequenceDiagram
Tests ->> Airflow Repo: Push constraints if changed
end
opt In merge run?
- GitHub Registry ->> Tests: Pull CI Image<br>[latest]
Note over Tests: Build CI Images<br>[latest]<br>Use latest constraints
Tests ->> GitHub Registry: Push CI Image<br>[latest]
- GitHub Registry ->> Tests: Pull PROD Image<br>[latest]
- Note over Tests: Build PROD Images<br>[latest]
+ Note over Tests: Build PROD Images<br>[latest]<br>Use latest constraints
Tests ->> GitHub Registry: Push PROD Image<br>[latest]
end
Tests -->> Airflow Repo: Status update
@@ -261,7 +258,7 @@ sequenceDiagram
deactivate Tests
```
-## Scheduled build flow
+## Scheduled run
```mermaid
sequenceDiagram
@@ -280,6 +277,10 @@ sequenceDiagram
Note over Tests: Test UI
and
Note over Tests: Test examples<br>PROD image building
+ and
+ Note over Tests: Build CI Images<br>Use original constraints
+ Tests ->> GitHub Registry: Push CI Image Early cache + latest
+ Note over Tests: Test 'breeze' image build quickly
end
Tests ->> GitHub Registry: Push CI Images<br>[COMMIT_SHA]
GitHub Registry ->> Tests: Pull CI Images<br>[COMMIT_SHA]
@@ -326,12 +327,10 @@ sequenceDiagram
end
Note over Tests: Generate constraints
Tests ->> Airflow Repo: Push constraints if changed
- GitHub Registry ->> Tests: Pull CI Image<br>[latest]
Note over Tests: Build CI Images<br>[latest]<br>Use latest constraints
- Tests ->> GitHub Registry: Push CI Image<br>[latest]
- GitHub Registry ->> Tests: Pull PROD Image<br>[latest]
- Note over Tests: Build PROD Images<br>[latest]
- Tests ->> GitHub Registry: Push PROD Image<br>[latest]
+ Tests ->> GitHub Registry: Push CI Image cache + latest
+ Note over Tests: Build PROD Images<br>[latest]<br>Use latest constraints
+ Tests ->> GitHub Registry: Push PROD Image cache + latest
Tests -->> Airflow Repo: Status update
deactivate Airflow Repo
deactivate Tests
diff --git a/COMMITTERS.rst b/COMMITTERS.rst
index 054988407cb60..3d05f2ed197e0 100644
--- a/COMMITTERS.rst
+++ b/COMMITTERS.rst
@@ -27,7 +27,7 @@ Before reading this document, you should be familiar with `Contributor's guide <
Guidelines to become an Airflow Committer
------------------------------------------
-Committers are community members who have write access to the project’s
+Committers are community members who have write access to the project's
repositories, i.e., they can modify the code, documentation, and website by themselves and also
accept other contributions. There is no strict protocol for becoming a committer. Candidates for new
committers are typically people that are active contributors and community members.
@@ -75,9 +75,10 @@ Code contribution
Community contributions
^^^^^^^^^^^^^^^^^^^^^^^^
-1. Was instrumental in triaging issues
-2. Improved documentation of Airflow in significant way
-3. Lead change and improvements introduction in the “community” processes and tools
+1. Actively participates in `triaging issues `_ showing their understanding
+ of various areas of Airflow and willingness to help other community members.
+2. Improves documentation of Airflow in a significant way
+3. Leads or implements the introduction of changes and improvements in the "community" processes and tools
4. Actively spreads the word about Airflow, for example organising Airflow summit, workshops for
community members, giving and recording talks, writing blogs
5. Reporting bugs with detailed reproduction steps
@@ -186,3 +187,4 @@ To be able to merge PRs, committers have to integrate their GitHub ID with Apach
3. Merge your Apache and GitHub accounts using `GitBox (Apache Account Linking utility) `__. You should see 3 green checks in GitBox.
4. Wait at least 30 minutes for an email inviting you to Apache GitHub Organization and accept invitation.
5. After accepting the GitHub Invitation verify that you are a member of the `Airflow committers team on GitHub `__.
+6. Ask in the ``#internal-airflow-ci-cd`` channel to be `configured in self-hosted runners `_ by the CI maintainers.
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 8bab15352577a..a5e00305c1297 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -134,7 +134,7 @@ and guidelines.
Committers/Maintainers
----------------------
-Committers are community members that have write access to the project’s repositories, i.e., they can modify the code,
+Committers are community members that have write access to the project's repositories, i.e., they can modify the code,
documentation, and website by themselves and also accept other contributions.
The official list of committers can be found `here `__.
@@ -277,7 +277,7 @@ For effective collaboration, make sure to join the following Airflow groups:
- Mailing lists:
- - Developer’s mailing list ``_
+ - Developer's mailing list ``_
(quite substantial traffic on this list)
- All commits mailing list: ``_
@@ -341,13 +341,11 @@ Step 4: Prepare PR
* `doc`
* `misc`
- To add a newsfragment, simply create an rst file named ``{pr_number}.{type}.rst`` (e.g. ``1234.bugfix.rst``)
+ To add a newsfragment, create an ``rst`` file named ``{pr_number}.{type}.rst`` (e.g. ``1234.bugfix.rst``)
and place in either `newsfragments `__ for core newsfragments,
or `chart/newsfragments `__ for helm chart newsfragments.
- For significant newsfragments, similar to git commits, the first line is the summary and optionally a
- body can be added with an empty line separating it.
- For other newsfragment types, only use a single summary line.
+ In general, newsfragments must be one line. For newsfragment type ``significant``, you may include a summary and a body separated by a blank line, similar to ``git`` commit messages.
2. Rebase your fork, squash commits, and resolve all conflicts. See `How to rebase PR <#how-to-rebase-pr>`_
if you need help with rebasing your change. Remember to rebase often if your PR takes a lot of time to
@@ -361,33 +359,6 @@ Step 4: Prepare PR
PR guidelines described in `pull request guidelines <#pull-request-guidelines>`_.
Create Pull Request! Make yourself ready for the discussion!
-5. Depending on "scope" of your changes, your Pull Request might go through one of few paths after approval.
- We run some non-standard workflow with high degree of automation that allows us to optimize the usage
- of queue slots in GitHub Actions. Our automated workflows determine the "scope" of changes in your PR
- and send it through the right path:
-
- * In case of a "no-code" change, approval will generate a comment that the PR can be merged and no
- tests are needed. This is usually when the change modifies some non-documentation related RST
- files (such as this file). No python tests are run and no CI images are built for such PR. Usually
- it can be approved and merged few minutes after it is submitted (unless there is a big queue of jobs).
-
- * In case of change involving python code changes or documentation changes, a subset of full test matrix
- will be executed. This subset of tests perform relevant tests for single combination of python, backend
- version and only builds one CI image and one PROD image. Here the scope of tests depends on the
- scope of your changes:
-
- * when your change does not change "core" of Airflow (Providers, CLI, WWW, Helm Chart) you will get the
- comment that PR is likely ok to be merged without running "full matrix" of tests. However decision
- for that is left to committer who approves your change. The committer might set a "full tests needed"
- label for your PR and ask you to rebase your request or re-run all jobs. PRs with "full tests needed"
- run full matrix of tests.
-
- * when your change changes the "core" of Airflow you will get the comment that PR needs full tests and
- the "full tests needed" label is set for your PR. Additional check is set that prevents from
- accidental merging of the request until full matrix of tests succeeds for the PR.
-
- More details about the PR workflow be found in `PULL_REQUEST_WORKFLOW.rst `_.
-
Step 5: Pass PR Review
----------------------
@@ -574,13 +545,6 @@ All details about using and running Airflow Breeze can be found in
The Airflow Breeze solution is intended to ease your local development as "*It's
a Breeze to develop Airflow*".
-.. note::
-
- We are in a process of switching to the new Python-based Breeze from a legacy Bash
- Breeze. Not all functionality has been ported yet and the old Breeze is still available
- until then as ``./breeze-legacy`` script. The documentation mentions when the old ./breeze-legacy
- should be still used.
-
Benefits:
- Breeze is a complete environment that includes external components, such as
@@ -648,16 +612,16 @@ This is the full list of those extras:
.. START EXTRAS HERE
airbyte, alibaba, all, all_dbs, amazon, apache.atlas, apache.beam, apache.cassandra, apache.drill,
apache.druid, apache.hdfs, apache.hive, apache.kylin, apache.livy, apache.pig, apache.pinot,
-apache.spark, apache.sqoop, apache.webhdfs, arangodb, asana, async, atlas, aws, azure, cassandra,
-celery, cgroups, cloudant, cncf.kubernetes, crypto, dask, databricks, datadog, dbt.cloud,
-deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord, doc, docker, druid,
-elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github, github_enterprise, google, google_auth,
-grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc, jenkins, jira, kerberos, kubernetes, ldap,
-leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp, microsoft.winrm, mongo, mssql, mysql,
-neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas, papermill, password, pinot, plexus,
-postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce, samba, segment, sendgrid, sentry,
-sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd, tableau, telegram, trino, vertica,
-virtualenv, webhdfs, winrm, yandex, zendesk
+apache.spark, apache.sqoop, apache.webhdfs, arangodb, asana, async, atlas, atlassian.jira, aws,
+azure, cassandra, celery, cgroups, cloudant, cncf.kubernetes, common.sql, crypto, dask, databricks,
+datadog, dbt.cloud, deprecated_api, devel, devel_all, devel_ci, devel_hadoop, dingding, discord,
+doc, doc_gen, docker, druid, elasticsearch, exasol, facebook, ftp, gcp, gcp_api, github,
+github_enterprise, google, google_auth, grpc, hashicorp, hdfs, hive, http, imap, influxdb, jdbc,
+jenkins, kerberos, kubernetes, ldap, leveldb, microsoft.azure, microsoft.mssql, microsoft.psrp,
+microsoft.winrm, mongo, mssql, mysql, neo4j, odbc, openfaas, opsgenie, oracle, pagerduty, pandas,
+papermill, password, pinot, plexus, postgres, presto, qds, qubole, rabbitmq, redis, s3, salesforce,
+samba, segment, sendgrid, sentry, sftp, singularity, slack, snowflake, spark, sqlite, ssh, statsd,
+tableau, tabular, telegram, trino, vertica, virtualenv, webhdfs, winrm, yandex, zendesk
.. END EXTRAS HERE
Provider packages
@@ -666,7 +630,23 @@ Provider packages
Airflow 2.0 is split into core and providers. They are delivered as separate packages:
* ``apache-airflow`` - core of Apache Airflow
-* ``apache-airflow-providers-*`` - More than 50 provider packages to communicate with external services
+* ``apache-airflow-providers-*`` - More than 70 provider packages to communicate with external services
+
+The metadata about each provider is kept in a ``provider.yaml`` file in the corresponding sub-directory
+of ``airflow/providers``. This file contains:
+
+* package name (``apache-airflow-providers-*``)
+* user-facing name of the provider package
+* description of the package that is available in the documentation
+* list of versions of package that have been released so far
+* list of dependencies of the provider package
+* list of additional-extras that the provider package provides (together with dependencies of those extras)
+* list of integrations, operators, hooks, sensors, transfers provided by the provider (useful for documentation generation)
+* list of connection types, extra-links, secret backends, auth backends, and logging handlers (useful to both
+ register them as they are needed by Airflow and to include them in documentation automatically).
+
+If you want to add dependencies to the provider, you should add them to the corresponding ``provider.yaml``
+and Airflow pre-commits and package generation commands will use them when preparing package information.
In Airflow 1.10 all those providers were installed together within one single package and when you installed
airflow locally, from sources, they were also installed. In Airflow 2.0, providers are separated out,
@@ -685,7 +665,7 @@ in this airflow folder - the providers package is importable.
Some of the packages have cross-dependencies with other providers packages. This typically happens for
transfer operators where operators use hooks from the other providers in case they are transferring
data between the providers. The list of dependencies is maintained (automatically with pre-commits)
-in the ``airflow/providers/dependencies.json``. Pre-commits are also used to generate dependencies.
+in the ``generated/provider_dependencies.json``. Pre-commits are also used to generate dependencies.
The dependency list is automatically used during PyPI packages generation.
Cross-dependencies between provider packages are converted into extras - if you need functionality from
@@ -695,49 +675,8 @@ the other provider package you can install it adding [extra] after the
transfer operators from Amazon ECS.
If you add a new dependency between different providers packages, it will be detected automatically during
-pre-commit phase and pre-commit will fail - and add entry in dependencies.json so that the package extra
-dependencies are properly added when package is installed.
-
-You can regenerate the whole list of provider dependencies by running this command (you need to have
-``pre-commits`` installed).
-
-.. code-block:: bash
-
- pre-commit run build-providers-dependencies
-
-
-Here is the list of packages and their extras:
-
-
- .. START PACKAGE DEPENDENCIES HERE
-
-========================== ===========================
-Package Extras
-========================== ===========================
-airbyte http
-amazon apache.hive,cncf.kubernetes,exasol,ftp,google,imap,mongo,salesforce,ssh
-apache.beam google
-apache.druid apache.hive
-apache.hive amazon,microsoft.mssql,mysql,presto,samba,vertica
-apache.livy http
-dbt.cloud http
-dingding http
-discord http
-google amazon,apache.beam,apache.cassandra,cncf.kubernetes,facebook,microsoft.azure,microsoft.mssql,mysql,oracle,postgres,presto,salesforce,sftp,ssh,trino
-hashicorp google
-microsoft.azure google,oracle,sftp
-mysql amazon,presto,trino,vertica
-postgres amazon
-presto google
-salesforce tableau
-sftp ssh
-slack http
-snowflake slack
-trino google
-========================== ===========================
-
- .. END PACKAGE DEPENDENCIES HERE
-
+the pre-commit phase, and pre-commit will generate a new entry in ``generated/provider_dependencies.json`` so that
+the package extra dependencies are properly handled when the package is installed.
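
As a rough sketch of how the cross-provider extras described above map to a pip requirement
(the ``provider_requirement`` helper is hypothetical and not part of Airflow; it assumes that dots in a
provider id such as ``microsoft.azure`` become dashes in the package name, while the extra keeps the
provider id unchanged):

.. code-block:: bash

    # Hypothetical helper, for illustration only: build the pip requirement
    # for a provider package together with one cross-provider extra.
    provider_requirement() {
        # Assumption: provider ids use dots, package names use dashes.
        pkg_suffix=$(printf '%s' "$1" | tr '.' '-')
        printf 'apache-airflow-providers-%s[%s]\n' "$pkg_suffix" "$2"
    }

    provider_requirement google amazon
    # prints: apache-airflow-providers-google[amazon]

    # Possible usage (requires network access, so it is left commented out):
    # pip install "$(provider_requirement google amazon)"

The requirement is quoted when passed to ``pip install`` so that the shell does not interpret the
square brackets.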
Developing community managed provider packages
----------------------------------------------
@@ -1190,12 +1129,12 @@ itself comes bundled with jQuery and bootstrap. While they may be phased out
over time, these packages are currently not managed with yarn.
Make sure you are using recent versions of node and yarn. No problems have been
-found with node\>=8.11.3 and yarn\>=1.19.1.
-
-Installing yarn and its packages
---------------------------------
+found with node\>=8.11.3 and yarn\>=1.19.1. Our pre-commit framework installs
+node and yarn automatically when it is installed - if you use ``breeze`` you do not need to install
+either node or yarn.
-Make sure yarn is available in your environment.
+Installing yarn and its packages manually
+-----------------------------------------
To install yarn on macOS:
@@ -1219,27 +1158,6 @@ To install yarn on macOS:
export PATH="$HOME/.yarn/bin:$PATH"
4. Install third-party libraries defined in ``package.json`` by running the
- following commands within the ``airflow/www/`` directory:
-
-
-.. code-block:: bash
-
- # from the root of the repository, move to where our JS package.json lives
- cd airflow/www/
- # run yarn install to fetch all the dependencies
- yarn install
-
-
-These commands install the libraries in a new ``node_modules/`` folder within
-``www/``.
-
-Should you add or upgrade a node package, run
-``yarn add --dev `` for packages needed in development or
-``yarn add `` for packages used by the code.
-Then push the newly generated ``package.json`` and ``yarn.lock`` file so that we
-could get a reproducible build. See the `Yarn docs
-`_ for more details.
-
Generate Bundled Files with yarn
--------------------------------
@@ -1253,7 +1171,7 @@ commands:
yarn run prod
# Starts a web server that manages and updates your assets as you modify them
- # You'll need to run the webserver in debug mode too: `airflow webserver -d`
+ # You'll need to run the webserver in debug mode too: ``airflow webserver -d``
yarn run dev
@@ -1285,10 +1203,10 @@ commands:
React, JSX and Chakra
-----------------------------
-In order to create a more modern UI, we have started to include [React](https://reactjs.org/) in the ``airflow/www/`` project.
+In order to create a more modern UI, we have started to include `React <https://reactjs.org/>`__ in the ``airflow/www/`` project.
If you are unfamiliar with React then it is recommended to check out their documentation to understand components and jsx syntax.
-We are using [Chakra UI](https://chakra-ui.com/) as a component and styling library. Notably, all styling is done in a theme file or
+We are using `Chakra UI <https://chakra-ui.com/>`__ as a component and styling library. Notably, all styling is done in a theme file or
inline when defining a component. There are a few shorthand style props like ``px`` instead of ``padding-right, padding-left``.
To make this work, all Chakra styling and css styling are completely separate. It is best to think of the React components as a separate app
that lives inside of the main app.
@@ -1526,14 +1444,14 @@ Here are a few rules that are important to keep in mind when you enter our commu
* There is a #newbie-questions channel in slack as a safe place to ask questions
* You can ask one of the committers to be a mentor for you, committers can guide within the community
* You can apply to more structured `Apache Mentoring Programme `_
-* It’s your responsibility as an author to take your PR from start-to-end including leading communication
+* It's your responsibility as an author to take your PR from start-to-end including leading communication
in the PR
-* It’s your responsibility as an author to ping committers to review your PR - be mildly annoying sometimes,
- it’s OK to be slightly annoying with your change - it is also a sign for committers that you care
+* It's your responsibility as an author to ping committers to review your PR - be mildly annoying sometimes,
+ it's OK to be slightly annoying with your change - it is also a sign for committers that you care
* Be considerate to the high code quality/test coverage requirements for Apache Airflow
* If in doubt - ask the community for their opinion or propose to vote at the devlist
* Discussions should concern subject matters - judge or criticise the merit but never criticise people
-* It’s OK to express your own emotions while communicating - it helps other people to understand you
+* It's OK to express your own emotions while communicating - it helps other people to understand you
* Be considerate for feelings of others. Tell about how you feel not what you think of others
Commit Policy
@@ -1549,6 +1467,6 @@ and slightly modified and consensus reached in October 2020:
Resources & Links
=================
-- `Airflow’s official documentation `__
+- `Airflow's official documentation `__
- `More resources and links to Airflow related content on the Wiki `__
diff --git a/CONTRIBUTORS_QUICK_START.rst b/CONTRIBUTORS_QUICK_START.rst
index 10ff2cc9516d9..bffd4967c8902 100644
--- a/CONTRIBUTORS_QUICK_START.rst
+++ b/CONTRIBUTORS_QUICK_START.rst
@@ -50,11 +50,11 @@ Local machine development
If you do not work with remote development environment, you need those prerequisites.
-1. Docker Community Edition
+1. Docker Community Edition (you can also use Colima, see instructions below)
2. Docker Compose
3. pyenv (you can also use pyenv-virtualenv or virtualenvwrapper)
-The below setup describe Ubuntu installation. It might be slightly different on different machines.
+The setup below describes `Ubuntu installation `_. It might be slightly different on different machines.
Docker Community Edition
------------------------
@@ -66,25 +66,24 @@ Docker Community Edition
$ sudo apt-get update
$ sudo apt-get install \
- apt-transport-https \
ca-certificates \
curl \
- gnupg-agent \
- software-properties-common
+ gnupg \
+ lsb-release
- $ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
+ $ sudo mkdir -p /etc/apt/keyrings
+ $ curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo gpg --dearmor -o /etc/apt/keyrings/docker.gpg
- $ sudo add-apt-repository \
- "deb [arch=amd64] https://download.docker.com/linux/ubuntu \
- $(lsb_release -cs) \
- stable"
+ $ echo \
+ "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] https://download.docker.com/linux/ubuntu \
+ $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null
-2. Install Docker
+2. Install Docker Engine, containerd, and Docker Compose Plugin.
.. code-block:: bash
$ sudo apt-get update
- $ sudo apt-get install docker-ce docker-ce-cli containerd.io
+ $ sudo apt-get install docker-ce docker-ce-cli containerd.io docker-compose-plugin
3. Creating group for docker and adding current user to it.
@@ -101,6 +100,24 @@ Note : After adding user to docker group Logout and Login again for group member
$ docker run hello-world
+Colima
+------
+If you use Colima as your container runtime engine, follow these steps:
+
+1. `Install buildx manually `_ and follow its instructions
+
+2. Link the Colima socket to the default socket path. Note that this may break other Docker servers.
+
+.. code-block:: bash
+
+ $ sudo ln -sf $HOME/.colima/default/docker.sock /var/run/docker.sock
+
+3. Change docker context to use default:
+
+.. code-block:: bash
+
+ $ docker context use default
+
Docker Compose
--------------
@@ -218,13 +235,13 @@ Setting up Breeze
.. code-block:: bash
- $ breeze setup-autocomplete
+ $ breeze setup autocomplete
4. Initialize breeze environment with required python version and backend. This may take a while for first time.
.. code-block:: bash
- $ breeze --python 3.7 --backend mysql
+ $ breeze --python 3.8 --backend postgres
.. note::
If you encounter an error like "docker.credentials.errors.InitializationError:
@@ -272,7 +289,7 @@ Using Breeze
------------
1. Starting breeze environment using ``breeze start-airflow`` starts Breeze environment with last configuration run(
- In this case python and backend will be picked up from last execution ``breeze --python 3.8 --backend mysql``)
+ In this case python and backend will be picked up from last execution ``breeze --python 3.8 --backend postgres``)
It also automatically starts webserver, backend and scheduler. It drops you in tmux with scheduler in bottom left
and webserver in bottom right. Use ``[Ctrl + B] and Arrow keys`` to navigate.
@@ -301,7 +318,7 @@ Using Breeze
* 26379 -> forwarded to Redis broker -> redis:6379
Here are links to those services that you can use on host:
- * ssh connection for remote debugging: ssh -p 12322 airflow@127.0.0.1 pw: airflow
+ * ssh connection for remote debugging: ssh -p 12322 airflow@127.0.0.1 (password: airflow)
* Webserver: http://127.0.0.1:28080
* Flower: http://127.0.0.1:25555
* Postgres: jdbc:postgresql://127.0.0.1:25433/airflow?user=postgres&password=airflow
@@ -324,7 +341,7 @@ Using Breeze
.. code-block:: bash
- $ breeze --python 3.8 --backend mysql
+ $ breeze --python 3.8 --backend postgres
2. Open tmux
@@ -505,7 +522,7 @@ To avoid burden on CI infrastructure and to save time, Pre-commit hooks can be r
.. code-block:: bash
- $ pre-commit run --files airflow/decorators.py tests/utils/test_task_group.py
+ $ pre-commit run --files airflow/utils/decorators.py tests/utils/test_task_group.py
@@ -515,7 +532,7 @@ To avoid burden on CI infrastructure and to save time, Pre-commit hooks can be r
$ pre-commit run black --files airflow/decorators.py tests/utils/test_task_group.py
black...............................................................Passed
- $ pre-commit run flake8 --files airflow/decorators.py tests/utils/test_task_group.py
+ $ pre-commit run run-flake8 --files airflow/decorators.py tests/utils/test_task_group.py
Run flake8..........................................................Passed
@@ -612,8 +629,7 @@ All Tests are inside ./tests directory.
.. code-block:: bash
- $ breeze --backend mysql --mysql-version 5.7 --python 3.8 --db-reset --test-type All tests
-
+ $ breeze --backend postgres --postgres-version 10 --python 3.8 --db-reset testing tests --test-type All
- Running specific type of test
@@ -623,7 +639,7 @@ All Tests are inside ./tests directory.
.. code-block:: bash
- $ breeze --backend mysql --mysql-version 5.7 --python 3.8 --db-reset --test-type Core
+ $ breeze --backend postgres --postgres-version 10 --python 3.8 --db-reset testing tests --test-type Core
- Running Integration test for specific test type
@@ -632,7 +648,7 @@ All Tests are inside ./tests directory.
.. code-block:: bash
- $ breeze --backend mysql --mysql-version 5.7 --python 3.8 --db-reset --test-type All --integration mongo
+ $ breeze --backend postgres --postgres-version 10 --python 3.8 --db-reset testing tests --test-type All --integration mongo
- For more information on Testing visit : |TESTING.rst|
diff --git a/Dockerfile b/Dockerfile
index cab46244ae954..3da82f2d2b15b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -44,11 +44,11 @@ ARG AIRFLOW_UID="50000"
ARG AIRFLOW_USER_HOME_DIR=/home/airflow
# latest released version here
-ARG AIRFLOW_VERSION="2.3.1"
+ARG AIRFLOW_VERSION="2.4.3"
ARG PYTHON_BASE_IMAGE="python:3.7-slim-bullseye"
-ARG AIRFLOW_PIP_VERSION=22.1.2
+ARG AIRFLOW_PIP_VERSION=22.3.1
ARG AIRFLOW_IMAGE_REPOSITORY="https://github.com/apache/airflow"
ARG AIRFLOW_IMAGE_README_URL="https://raw.githubusercontent.com/apache/airflow/main/docs/docker-stack/README.md"
@@ -71,39 +71,106 @@ FROM scratch as scripts
# make the PROD Dockerfile standalone
##############################################################################################
-# The content below is automatically copied from scripts/docker/determine_debian_version_specific_variables.sh
-COPY <<"EOF" /determine_debian_version_specific_variables.sh
-function determine_debian_version_specific_variables() {
- local color_red
- color_red=$'\e[31m'
- local color_reset
- color_reset=$'\e[0m'
-
- local debian_version
- debian_version=$(lsb_release -cs)
- if [[ ${debian_version} == "buster" ]]; then
- export DISTRO_LIBENCHANT="libenchant-dev"
- export DISTRO_LIBGCC="libgcc-8-dev"
- export DISTRO_SELINUX="python-selinux"
- export DISTRO_LIBFFI="libffi6"
- # Note missing man directories on debian-buster
- # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=863199
- mkdir -pv /usr/share/man/man1
- mkdir -pv /usr/share/man/man7
- elif [[ ${debian_version} == "bullseye" ]]; then
- export DISTRO_LIBENCHANT="libenchant-2-2"
- export DISTRO_LIBGCC="libgcc-10-dev"
- export DISTRO_SELINUX="python3-selinux"
- export DISTRO_LIBFFI="libffi7"
+# The content below is automatically copied from scripts/docker/install_os_dependencies.sh
+COPY <<"EOF" /install_os_dependencies.sh
+set -euo pipefail
+
+DOCKER_CLI_VERSION=20.10.9
+
+if [[ "$#" != 1 ]]; then
+ echo "ERROR! There should be 'runtime' or 'dev' parameter passed as argument.".
+ exit 1
+fi
+
+if [[ "${1}" == "runtime" ]]; then
+ INSTALLATION_TYPE="RUNTIME"
+elif [[ "${1}" == "dev" ]]; then
+ INSTALLATION_TYPE="dev"
+else
+ echo "ERROR! Wrong argument. Passed ${1} and it should be one of 'runtime' or 'dev'.".
+ exit 1
+fi
+
+function get_dev_apt_deps() {
+ if [[ "${DEV_APT_DEPS=}" == "" ]]; then
+ DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \
+freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev \
+libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \
+libssl-dev locales lsb-release openssh-client sasl2-bin \
+software-properties-common sqlite3 sudo unixodbc unixodbc-dev"
+ export DEV_APT_DEPS
+ fi
+}
+
+function get_runtime_apt_deps() {
+ if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then
+ RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \
+curl dumb-init freetds-bin gosu krb5-user \
+ldap-utils libffi7 libldap-2.4-2 libsasl2-2 libsasl2-modules libssl1.1 locales \
+lsb-release netcat openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc"
+ export RUNTIME_APT_DEPS
+ fi
+}
+
+function install_docker_cli() {
+ local platform
+ if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then
+ platform="aarch64"
else
- echo
- echo "${color_red}Unknown distro version ${debian_version}${color_reset}"
- echo
- exit 1
+ platform="x86_64"
fi
+ curl --silent \
+ "https://download.docker.com/linux/static/stable/${platform}/docker-${DOCKER_CLI_VERSION}.tgz" \
+ | tar -C /usr/bin --strip-components=1 -xvzf - docker/docker
}
-determine_debian_version_specific_variables
+function install_debian_dev_dependencies() {
+ apt-get update
+ apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1
+ apt-get install -y --no-install-recommends curl gnupg2 lsb-release
+ # shellcheck disable=SC2086
+ export ${ADDITIONAL_DEV_APT_ENV?}
+ if [[ ${DEV_APT_COMMAND} != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}"
+ fi
+ if [[ ${ADDITIONAL_DEV_APT_COMMAND} != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}"
+ fi
+ apt-get update
+ # shellcheck disable=SC2086
+ apt-get install -y --no-install-recommends ${DEV_APT_DEPS} ${ADDITIONAL_DEV_APT_DEPS}
+}
+
+function install_debian_runtime_dependencies() {
+ apt-get update
+ apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1
+ apt-get install -y --no-install-recommends curl gnupg2 lsb-release
+ # shellcheck disable=SC2086
+ export ${ADDITIONAL_RUNTIME_APT_ENV?}
+ if [[ "${RUNTIME_APT_COMMAND}" != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}"
+ fi
+ if [[ "${ADDITIONAL_RUNTIME_APT_COMMAND}" != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}"
+ fi
+ apt-get update
+ # shellcheck disable=SC2086
+ apt-get install -y --no-install-recommends ${RUNTIME_APT_DEPS} ${ADDITIONAL_RUNTIME_APT_DEPS}
+ apt-get autoremove -yqq --purge
+ apt-get clean
+ rm -rf /var/lib/apt/lists/* /var/log/*
+}
+
+if [[ "${INSTALLATION_TYPE}" == "RUNTIME" ]]; then
+ get_runtime_apt_deps
+ install_debian_runtime_dependencies
+ install_docker_cli
+
+else
+ get_dev_apt_deps
+ install_debian_dev_dependencies
+ install_docker_cli
+fi
EOF
# The content below is automatically copied from scripts/docker/install_mysql.sh
@@ -178,8 +245,6 @@ set -euo pipefail
COLOR_BLUE=$'\e[34m'
readonly COLOR_BLUE
-COLOR_YELLOW=$'\e[33m'
-readonly COLOR_YELLOW
COLOR_RESET=$'\e[0m'
readonly COLOR_RESET
@@ -197,19 +262,8 @@ function install_mssql_client() {
local distro
local version
distro=$(lsb_release -is | tr '[:upper:]' '[:lower:]')
- version_name=$(lsb_release -cs | tr '[:upper:]' '[:lower:]')
version=$(lsb_release -rs)
- local driver
- if [[ ${version_name} == "buster" ]]; then
- driver=msodbcsql17
- elif [[ ${version_name} == "bullseye" ]]; then
- driver=msodbcsql18
- else
- echo
- echo "${COLOR_YELLOW}Only Buster or Bullseye are supported. Skipping MSSQL installation${COLOR_RESET}"
- echo
- return
- fi
+ local driver=msodbcsql18
curl --silent https://packages.microsoft.com/keys/microsoft.asc | apt-key add - >/dev/null 2>&1
curl --silent "https://packages.microsoft.com/config/${distro}/${version}/prod.list" > \
/etc/apt/sources.list.d/mssql-release.list
@@ -221,11 +275,6 @@ function install_mssql_client() {
apt-get clean && rm -rf /var/lib/apt/lists/*
}
-if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then
- # disable MSSQL for ARM64
- INSTALL_MSSQL_CLIENT="false"
-fi
-
install_mssql_client "${@}"
EOF
@@ -276,20 +325,12 @@ COPY <<"EOF" /install_pip_version.sh
: "${AIRFLOW_PIP_VERSION:?Should be set}"
-function install_pip_version() {
- echo
- echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}"
- echo
- pip install --disable-pip-version-check --no-cache-dir --upgrade "pip==${AIRFLOW_PIP_VERSION}" &&
- mkdir -p ${HOME}/.local/bin
-}
-
common::get_colors
common::get_airflow_version_specification
common::override_pip_version_if_needed
common::show_pip_version_and_location
-install_pip_version
+common::install_pip_version
EOF
# The content below is automatically copied from scripts/docker/install_airflow_dependencies_from_branch_tip.sh
@@ -317,10 +358,10 @@ function install_airflow_dependencies_from_branch_tip() {
# are conflicts, this might fail, but it should be fixed in the following installation steps
set -x
pip install --root-user-action ignore \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"https://github.com/${AIRFLOW_REPO}/archive/${AIRFLOW_BRANCH}.tar.gz#egg=apache-airflow[${AIRFLOW_EXTRAS}]" \
--constraint "${AIRFLOW_CONSTRAINTS_LOCATION}" || true
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
pip freeze | grep apache-airflow-providers | xargs pip uninstall --yes 2>/dev/null || true
set +x
echo
@@ -367,7 +408,7 @@ function common::get_airflow_version_specification() {
function common::override_pip_version_if_needed() {
if [[ -n ${AIRFLOW_VERSION} ]]; then
if [[ ${AIRFLOW_VERSION} =~ ^2\.0.* || ${AIRFLOW_VERSION} =~ ^1\.* ]]; then
- export AIRFLOW_PIP_VERSION="22.1.2"
+ export AIRFLOW_PIP_VERSION="22.3.1"
fi
fi
}
@@ -395,101 +436,18 @@ function common::show_pip_version_and_location() {
echo "pip on path: $(which pip)"
echo "Using pip: $(pip --version)"
}
-EOF
-
-# The content below is automatically copied from scripts/docker/prepare_node_modules.sh
-COPY <<"EOF" /prepare_node_modules.sh
-set -euo pipefail
-
-COLOR_BLUE=$'\e[34m'
-readonly COLOR_BLUE
-COLOR_RESET=$'\e[0m'
-readonly COLOR_RESET
-
-function prepare_node_modules() {
- echo
- echo "${COLOR_BLUE}Preparing node modules${COLOR_RESET}"
- echo
- local www_dir
- if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then
- # In case we are building from sources in production image, we should build the assets
- www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www"
- else
- www_dir="$(python -m site --user-site)/airflow/www"
- fi
- pushd ${www_dir} || exit 1
- set +e
- yarn install --frozen-lockfile --no-cache 2>/tmp/out-yarn-install.txt
- local res=$?
- if [[ ${res} != 0 ]]; then
- >&2 echo
- >&2 echo "Error when running yarn install:"
- >&2 echo
- >&2 cat /tmp/out-yarn-install.txt && rm -f /tmp/out-yarn-install.txt
- exit 1
- fi
- rm -f /tmp/out-yarn-install.txt
- popd || exit 1
-}
-prepare_node_modules
-EOF
-
-# The content below is automatically copied from scripts/docker/compile_www_assets.sh
-COPY <<"EOF" /compile_www_assets.sh
-set -euo pipefail
-
-BUILD_TYPE=${BUILD_TYPE="prod"}
-REMOVE_ARTIFACTS=${REMOVE_ARTIFACTS="true"}
-
-COLOR_BLUE=$'\e[34m'
-readonly COLOR_BLUE
-COLOR_RESET=$'\e[0m'
-readonly COLOR_RESET
-
-function compile_www_assets() {
+function common::install_pip_version() {
echo
- echo "${COLOR_BLUE}Compiling www assets: running yarn ${BUILD_TYPE}${COLOR_RESET}"
+ echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}"
echo
- local www_dir
- if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then
- # In case we are building from sources in production image, we should build the assets
- www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www"
+ if [[ ${AIRFLOW_PIP_VERSION} =~ .*https.* ]]; then
+ pip install --disable-pip-version-check --no-cache-dir "pip @ ${AIRFLOW_PIP_VERSION}"
else
- www_dir="$(python -m site --user-site)/airflow/www"
+ pip install --disable-pip-version-check --no-cache-dir "pip==${AIRFLOW_PIP_VERSION}"
fi
- pushd ${www_dir} || exit 1
- set +e
- yarn run "${BUILD_TYPE}" 2>/tmp/out-yarn-run.txt
- res=$?
- if [[ ${res} != 0 ]]; then
- >&2 echo
- >&2 echo "Error when running yarn run:"
- >&2 echo
- >&2 cat /tmp/out-yarn-run.txt && rm -rf /tmp/out-yarn-run.txt
- exit 1
- fi
- rm -f /tmp/out-yarn-run.txt
- set -e
- local md5sum_file
- md5sum_file="static/dist/sum.md5"
- readonly md5sum_file
- find package.json yarn.lock static/css static/js -type f | sort | xargs md5sum > "${md5sum_file}"
- if [[ ${REMOVE_ARTIFACTS} == "true" ]]; then
- echo
- echo "${COLOR_BLUE}Removing generated node modules${COLOR_RESET}"
- echo
- rm -rf "${www_dir}/node_modules"
- rm -vf "${www_dir}"/{package.json,yarn.lock,.eslintignore,.eslintrc,.stylelintignore,.stylelintrc,compile_assets.sh,webpack.config.js}
- else
- echo
- echo "${COLOR_BLUE}Leaving generated node modules${COLOR_RESET}"
- echo
- fi
- popd || exit 1
+ mkdir -p "${HOME}/.local/bin"
}
-
-compile_www_assets
EOF
# The content below is automatically copied from scripts/docker/pip
@@ -571,12 +529,12 @@ function install_airflow_and_providers_from_docker_context_files(){
# force reinstall all airflow + provider package local files with eager upgrade
set -x
pip install "${pip_flags[@]}" --root-user-action ignore --upgrade --upgrade-strategy eager \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${reinstalling_apache_airflow_package} ${reinstalling_apache_airflow_providers_packages} \
${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS}
set +x
- # make sure correct PIP version is left installed
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
pip check
}
@@ -591,10 +549,10 @@ function install_all_other_packages_from_docker_context_files() {
grep -v apache_airflow | grep -v apache-airflow || true)
if [[ -n "${reinstalling_other_packages}" ]]; then
set -x
- pip install --root-user-action ignore --force-reinstall --no-deps --no-index ${reinstalling_other_packages}
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
- set -x
+ pip install ${ADDITIONAL_PIP_INSTALL_FLAGS} \
+ --root-user-action ignore --force-reinstall --no-deps --no-index ${reinstalling_other_packages}
+ common::install_pip_version
+ set +x
fi
}
@@ -642,6 +600,7 @@ function install_airflow() {
echo
# eager upgrade
pip install --root-user-action ignore --upgrade --upgrade-strategy eager \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \
${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS}
if [[ -n "${AIRFLOW_INSTALL_EDITABLE_FLAG}" ]]; then
@@ -650,12 +609,12 @@ function install_airflow() {
set -x
pip uninstall apache-airflow --yes
pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}"
set +x
fi
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
echo
@@ -666,16 +625,16 @@ function install_airflow() {
echo
set -x
pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \
--constraint "${AIRFLOW_CONSTRAINTS_LOCATION}"
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
# then upgrade if needed without using constraints to account for new limits in setup.py
pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${AIRFLOW_INSTALL_EDITABLE_FLAG} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}"
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
set +x
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
@@ -712,9 +671,9 @@ function install_additional_dependencies() {
echo
set -x
pip install --root-user-action ignore --upgrade --upgrade-strategy eager \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${ADDITIONAL_PYTHON_DEPS} ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS}
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
set +x
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
@@ -726,9 +685,9 @@ function install_additional_dependencies() {
echo
set -x
pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${ADDITIONAL_PYTHON_DEPS}
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
set +x
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
@@ -1100,44 +1059,10 @@ ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \
DEBIAN_FRONTEND=noninteractive LANGUAGE=C.UTF-8 LANG=C.UTF-8 LC_ALL=C.UTF-8 \
LC_CTYPE=C.UTF-8 LC_MESSAGES=C.UTF-8
-ARG DEV_APT_DEPS="\
- apt-transport-https \
- apt-utils \
- build-essential \
- ca-certificates \
- dirmngr \
- freetds-bin \
- freetds-dev \
- gosu \
- krb5-user \
- ldap-utils \
- libffi-dev \
- libkrb5-dev \
- libldap2-dev \
- libsasl2-2 \
- libsasl2-dev \
- libsasl2-modules \
- libssl-dev \
- locales \
- lsb-release \
- nodejs \
- openssh-client \
- sasl2-bin \
- software-properties-common \
- sqlite3 \
- sudo \
- unixodbc \
- unixodbc-dev \
- yarn"
-
+ARG DEV_APT_DEPS=""
ARG ADDITIONAL_DEV_APT_DEPS=""
-ARG DEV_APT_COMMAND="\
- curl --silent --fail --location https://deb.nodesource.com/setup_14.x | \
- bash -o pipefail -o errexit -o nolog - \
- && curl --silent https://dl.yarnpkg.com/debian/pubkey.gpg | \
- apt-key add - >/dev/null 2>&1\
- && echo 'deb https://dl.yarnpkg.com/debian/ stable main' > /etc/apt/sources.list.d/yarn.list"
-ARG ADDITIONAL_DEV_APT_COMMAND="echo"
+ARG DEV_APT_COMMAND=""
+ARG ADDITIONAL_DEV_APT_COMMAND=""
ARG ADDITIONAL_DEV_APT_ENV=""
ENV DEV_APT_DEPS=${DEV_APT_DEPS} \
@@ -1146,23 +1071,8 @@ ENV DEV_APT_DEPS=${DEV_APT_DEPS} \
ADDITIONAL_DEV_APT_COMMAND=${ADDITIONAL_DEV_APT_COMMAND} \
ADDITIONAL_DEV_APT_ENV=${ADDITIONAL_DEV_APT_ENV}
-COPY --from=scripts determine_debian_version_specific_variables.sh /scripts/docker/
-# Install basic and additional apt dependencies
-RUN apt-get update \
- && apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 \
- && apt-get install -y --no-install-recommends curl gnupg2 lsb-release \
- && export ${ADDITIONAL_DEV_APT_ENV?} \
- && source /scripts/docker/determine_debian_version_specific_variables.sh \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}" \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}" \
- && apt-get update \
- && apt-get install -y --no-install-recommends \
- ${DEV_APT_DEPS} \
- "${DISTRO_SELINUX}" \
- ${ADDITIONAL_DEV_APT_DEPS} \
- && apt-get autoremove -yqq --purge \
- && apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+COPY --from=scripts install_os_dependencies.sh /scripts/docker/
+RUN bash /scripts/docker/install_os_dependencies.sh dev
ARG INSTALL_MYSQL_CLIENT="true"
ARG INSTALL_MSSQL_CLIENT="true"
@@ -1196,18 +1106,10 @@ ARG INSTALL_PROVIDERS_FROM_SOURCES="false"
# But it also can be `.` from local installation or GitHub URL pointing to specific branch or tag
# of Airflow. Note that for local source installation you need to have local sources of
# Airflow checked out together with the Dockerfile and AIRFLOW_SOURCES_FROM and AIRFLOW_SOURCES_TO
-# set to "." and "/opt/airflow" respectively. Similarly AIRFLOW_SOURCES_WWW_FROM/TO are set to right source
-# and destination
+# set to "." and "/opt/airflow" respectively.
ARG AIRFLOW_INSTALLATION_METHOD="apache-airflow"
# By default we do not upgrade to latest dependencies
ARG UPGRADE_TO_NEWER_DEPENDENCIES="false"
-# By default we install latest airflow from PyPI so we do not need to copy sources of Airflow
-# www to compile the assets but in case of breeze/CI builds we use latest sources and we override those
-# those SOURCES_FROM/TO with "airflow/www" and "/opt/airflow/airflow/www" respectively.
-# This is to rebuild the assets only when any of the www sources change
-ARG AIRFLOW_SOURCES_WWW_FROM="Dockerfile"
-ARG AIRFLOW_SOURCES_WWW_TO="/Dockerfile"
-
# By default we install latest airflow from PyPI so we do not need to copy sources of Airflow
# but in case of breeze/CI builds we use latest sources and we override
# those SOURCES_FROM/TO with "." and "/opt/airflow" respectively
@@ -1251,6 +1153,9 @@ RUN if [[ -f /docker-context-files/pip.conf ]]; then \
cp /docker-context-files/.piprc "${AIRFLOW_USER_HOME_DIR}/.piprc"; \
fi
+# Additional PIP flags passed to all pip install commands except reinstalling pip itself
+ARG ADDITIONAL_PIP_INSTALL_FLAGS=""
+
ENV AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \
AIRFLOW_PRE_CACHED_PIP_PACKAGES=${AIRFLOW_PRE_CACHED_PIP_PACKAGES} \
INSTALL_PROVIDERS_FROM_SOURCES=${INSTALL_PROVIDERS_FROM_SOURCES} \
@@ -1270,6 +1175,7 @@ ENV AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \
PATH=${PATH}:${AIRFLOW_USER_HOME_DIR}/.local/bin \
AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \
PIP_PROGRESS_BAR=${PIP_PROGRESS_BAR} \
+ ADDITIONAL_PIP_INSTALL_FLAGS=${ADDITIONAL_PIP_INSTALL_FLAGS} \
AIRFLOW_USER_HOME_DIR=${AIRFLOW_USER_HOME_DIR} \
AIRFLOW_HOME=${AIRFLOW_HOME} \
AIRFLOW_UID=${AIRFLOW_UID} \
@@ -1283,61 +1189,50 @@ ENV AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION} \
COPY --from=scripts common.sh install_pip_version.sh \
install_airflow_dependencies_from_branch_tip.sh /scripts/docker/
+# We can set this value to true in case we want to install .whl/.tar.gz packages placed in the
+# docker-context-files folder. This can be done for both additional packages you want to install
+# as well as Airflow and Provider packages (it will be automatically detected if airflow
+# is installed from docker-context-files rather than from PyPI)
+ARG INSTALL_PACKAGES_FROM_CONTEXT="false"
+
# In case of Production build image segment we want to pre-install main version of airflow
# dependencies from GitHub so that we do not have to always reinstall it from the scratch.
# The Airflow (and providers in case INSTALL_PROVIDERS_FROM_SOURCES is "false")
# are uninstalled, only dependencies remain
# the cache is only used when "upgrade to newer dependencies" is not set to automatically
-# account for removed dependencies (we do not install them in the first place)
-# Upgrade to specific PIP version
+# account for removed dependencies (we do not install them in the first place) and in case
+# INSTALL_PACKAGES_FROM_CONTEXT is not set (because then caching it from main makes no sense).
RUN bash /scripts/docker/install_pip_version.sh; \
if [[ ${AIRFLOW_PRE_CACHED_PIP_PACKAGES} == "true" && \
- ${UPGRADE_TO_NEWER_DEPENDENCIES} == "false" ]]; then \
+ ${INSTALL_PACKAGES_FROM_CONTEXT} == "false" && \
+ ${UPGRADE_TO_NEWER_DEPENDENCIES} == "false" ]]; then \
bash /scripts/docker/install_airflow_dependencies_from_branch_tip.sh; \
fi
-COPY --from=scripts compile_www_assets.sh prepare_node_modules.sh /scripts/docker/
-COPY --chown=airflow:0 ${AIRFLOW_SOURCES_WWW_FROM} ${AIRFLOW_SOURCES_WWW_TO}
-
-# hadolint ignore=SC2086, SC2010
-RUN if [[ ${AIRFLOW_INSTALLATION_METHOD} == "." ]]; then \
- # only prepare node modules and compile assets if the prod image is build from sources
- # otherwise they are already compiled-in. We should do it in one step with removing artifacts \
- # as we want to keep the final image small
- bash /scripts/docker/prepare_node_modules.sh; \
- REMOVE_ARTIFACTS="true" BUILD_TYPE="prod" bash /scripts/docker/compile_www_assets.sh; \
- # Copy generated dist folder (otherwise it will be overridden by the COPY step below)
- mv -f /opt/airflow/airflow/www/static/dist /tmp/dist; \
- fi;
-
COPY --chown=airflow:0 ${AIRFLOW_SOURCES_FROM} ${AIRFLOW_SOURCES_TO}
-# Copy back the generated dist folder
-RUN if [[ ${AIRFLOW_INSTALLATION_METHOD} == "." ]]; then \
- mv -f /tmp/dist /opt/airflow/airflow/www/static/dist; \
- fi;
-
# Add extra python dependencies
ARG ADDITIONAL_PYTHON_DEPS=""
-# We can set this value to true in case we want to install .whl .tar.gz packages placed in the
-# docker-context-files folder. This can be done for both - additional packages you want to install
-# and for airflow as well (you have to set AIRFLOW_IS_IN_CONTEXT to true in this case)
-ARG INSTALL_PACKAGES_FROM_CONTEXT="false"
-# By default we install latest airflow from PyPI or sources. You can set this parameter to false
-# if Airflow is in the .whl or .tar.gz packages placed in `docker-context-files` folder and you want
-# to skip installing Airflow/Providers from PyPI or sources.
-ARG AIRFLOW_IS_IN_CONTEXT="false"
+
# Those are additional constraints that are needed for some extras but we do not want to
# Force them on the main Airflow package.
# * dill<0.3.3 required by apache-beam
-ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3"
+# * pyarrow>=6.0.0 is because pip resolver decides for Python 3.10 to downgrade pyarrow to 5 even if it is OK
+# for python 3.10 and other dependencies. Adding the limit helps the resolver make better decisions
+# We need to limit the protobuf library to < 4.21.0 because not all google libraries we use
+# are compatible with the new protobuf version. All the google python client libraries need
+# to be upgraded to >=2.0.0 in order to be able to lift that limitation
+# https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
+# * authlib, gcloud_aio_auth, adal are needed to generate constraints for PyPI packages and can be removed after we release
+# new google, azure providers
+# !!! MAKE SURE YOU SYNCHRONIZE THE LIST BETWEEN: Dockerfile, Dockerfile.ci, find_newer_dependencies.py
+ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3 pyarrow>=6.0.0 protobuf<4.21.0 authlib>=1.0.0 gcloud_aio_auth>=4.0.0 adal>=1.2.7"
ENV ADDITIONAL_PYTHON_DEPS=${ADDITIONAL_PYTHON_DEPS} \
INSTALL_PACKAGES_FROM_CONTEXT=${INSTALL_PACKAGES_FROM_CONTEXT} \
- AIRFLOW_IS_IN_CONTEXT=${AIRFLOW_IS_IN_CONTEXT} \
EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS=${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS}
-WORKDIR /opt/airflow
+WORKDIR ${AIRFLOW_HOME}
COPY --from=scripts install_from_docker_context_files.sh install_airflow.sh \
install_additional_dependencies.sh /scripts/docker/
@@ -1345,7 +1240,8 @@ COPY --from=scripts install_from_docker_context_files.sh install_airflow.sh \
# hadolint ignore=SC2086, SC2010
RUN if [[ ${INSTALL_PACKAGES_FROM_CONTEXT} == "true" ]]; then \
bash /scripts/docker/install_from_docker_context_files.sh; \
- elif [[ ${AIRFLOW_IS_IN_CONTEXT} == "false" ]]; then \
+ fi; \
+ if ! airflow version 2>/dev/null >/dev/null; then \
bash /scripts/docker/install_airflow.sh; \
fi; \
if [[ -n "${ADDITIONAL_PYTHON_DEPS}" ]]; then \
@@ -1388,32 +1284,10 @@ ARG AIRFLOW_PIP_VERSION
ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \
# Make sure noninteractive debian install is used and language variables set
DEBIAN_FRONTEND=noninteractive LANGUAGE=C.UTF-8 LANG=C.UTF-8 LC_ALL=C.UTF-8 \
- LC_CTYPE=C.UTF-8 LC_MESSAGES=C.UTF-8 \
+ LC_CTYPE=C.UTF-8 LC_MESSAGES=C.UTF-8 LD_LIBRARY_PATH=/usr/local/lib \
AIRFLOW_PIP_VERSION=${AIRFLOW_PIP_VERSION}
-ARG RUNTIME_APT_DEPS="\
- apt-transport-https \
- apt-utils \
- ca-certificates \
- curl \
- dumb-init \
- freetds-bin \
- gosu \
- krb5-user \
- ldap-utils \
- libldap-2.4-2 \
- libsasl2-2 \
- libsasl2-modules \
- libssl1.1 \
- locales \
- lsb-release \
- netcat \
- openssh-client \
- rsync \
- sasl2-bin \
- sqlite3 \
- sudo \
- unixodbc"
+ARG RUNTIME_APT_DEPS=""
ARG ADDITIONAL_RUNTIME_APT_DEPS=""
ARG RUNTIME_APT_COMMAND="echo"
ARG ADDITIONAL_RUNTIME_APT_COMMAND=""
@@ -1432,25 +1306,8 @@ ENV RUNTIME_APT_DEPS=${RUNTIME_APT_DEPS} \
GUNICORN_CMD_ARGS="--worker-tmp-dir /dev/shm" \
AIRFLOW_INSTALLATION_METHOD=${AIRFLOW_INSTALLATION_METHOD}
-COPY --from=scripts determine_debian_version_specific_variables.sh /scripts/docker/
-
-# Install basic and additional apt dependencies
-RUN apt-get update \
- && apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 \
- && apt-get install -y --no-install-recommends curl gnupg2 lsb-release \
- && export ${ADDITIONAL_RUNTIME_APT_ENV?} \
- && source /scripts/docker/determine_debian_version_specific_variables.sh \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}" \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}" \
- && apt-get update \
- && apt-get install -y --no-install-recommends \
- ${RUNTIME_APT_DEPS} \
- "${DISTRO_LIBFFI}" \
- ${ADDITIONAL_RUNTIME_APT_DEPS} \
- && apt-get autoremove -yqq --purge \
- && apt-get clean \
- && rm -rf /var/lib/apt/lists/* \
- && rm -rf /var/log/*
+COPY --from=scripts install_os_dependencies.sh /scripts/docker/
+RUN bash /scripts/docker/install_os_dependencies.sh runtime
# Having the variable in final image allows to disable providers manager warnings when
# production image is prepared from sources rather than from package
@@ -1472,7 +1329,7 @@ COPY --from=scripts install_mysql.sh install_mssql.sh install_postgres.sh /scrip
# We run scripts with bash here to make sure we can execute the scripts. Changing to +x might have an
# unexpected result - the cache for Dockerfiles might get invalidated in case the host system
# had different umask set and group x bit was not set. In Azure the bit might be not set at all.
-# That also protects against AUFS Docker backen dproblem where changing the executable bit required sync
+# That also protects against AUFS Docker backend problem where changing the executable bit required sync
RUN bash /scripts/docker/install_mysql.sh prod \
&& bash /scripts/docker/install_mssql.sh \
&& bash /scripts/docker/install_postgres.sh prod \
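Note on the Dockerfile hunks above: when INSTALL_PACKAGES_FROM_CONTEXT is "true", the RUN step first installs from docker-context-files, and it then runs install_airflow.sh only if `airflow version` fails. This probe replaces the removed AIRFLOW_IS_IN_CONTEXT flag. The sketch below shows that probe-then-install pattern in isolation. `needs_install` is a hypothetical helper name used only for illustration and is not part of the patch.

```shell
#!/usr/bin/env bash

# Hypothetical helper (not in the patch): returns 0 when the probe command
# fails or is missing, i.e. when an install step is still needed.
needs_install() {
    if "$@" >/dev/null 2>&1; then
        return 1
    fi
    return 0
}

# Same shape as the second `if` in the RUN step: probe first, install only on failure.
if needs_install airflow version; then
    echo "airflow is not runnable yet; the image build would call install_airflow.sh here"
else
    echo "airflow is already installed (for example from docker-context-files); skipping"
fi
```

Probing the installed command rather than trusting a build argument means the build cannot be misconfigured by a flag that disagrees with what is actually in docker-context-files.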
diff --git a/Dockerfile.ci b/Dockerfile.ci
index c089544f244b3..cc28c355d916d 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -31,39 +31,106 @@ FROM ${PYTHON_BASE_IMAGE} as scripts
# make the PROD Dockerfile standalone
##############################################################################################
-# The content below is automatically copied from scripts/docker/determine_debian_version_specific_variables.sh
-COPY <<"EOF" /determine_debian_version_specific_variables.sh
-function determine_debian_version_specific_variables() {
- local color_red
- color_red=$'\e[31m'
- local color_reset
- color_reset=$'\e[0m'
-
- local debian_version
- debian_version=$(lsb_release -cs)
- if [[ ${debian_version} == "buster" ]]; then
- export DISTRO_LIBENCHANT="libenchant-dev"
- export DISTRO_LIBGCC="libgcc-8-dev"
- export DISTRO_SELINUX="python-selinux"
- export DISTRO_LIBFFI="libffi6"
- # Note missing man directories on debian-buster
- # https://bugs.debian.org/cgi-bin/bugreport.cgi?bug=863199
- mkdir -pv /usr/share/man/man1
- mkdir -pv /usr/share/man/man7
- elif [[ ${debian_version} == "bullseye" ]]; then
- export DISTRO_LIBENCHANT="libenchant-2-2"
- export DISTRO_LIBGCC="libgcc-10-dev"
- export DISTRO_SELINUX="python3-selinux"
- export DISTRO_LIBFFI="libffi7"
+# The content below is automatically copied from scripts/docker/install_os_dependencies.sh
+COPY <<"EOF" /install_os_dependencies.sh
+set -euo pipefail
+
+DOCKER_CLI_VERSION=20.10.9
+
+if [[ "$#" != 1 ]]; then
+ echo "ERROR! There should be 'runtime' or 'dev' parameter passed as argument."
+ exit 1
+fi
+
+if [[ "${1}" == "runtime" ]]; then
+ INSTALLATION_TYPE="RUNTIME"
+elif [[ "${1}" == "dev" ]]; then
+ INSTALLATION_TYPE="dev"
+else
+ echo "ERROR! Wrong argument. Passed ${1} and it should be one of 'runtime' or 'dev'."
+ exit 1
+fi
+
+function get_dev_apt_deps() {
+ if [[ "${DEV_APT_DEPS=}" == "" ]]; then
+ DEV_APT_DEPS="apt-transport-https apt-utils build-essential ca-certificates dirmngr \
+freetds-bin freetds-dev git gosu graphviz graphviz-dev krb5-user ldap-utils libffi-dev \
+libkrb5-dev libldap2-dev libleveldb1d libleveldb-dev libsasl2-2 libsasl2-dev libsasl2-modules \
+libssl-dev locales lsb-release openssh-client sasl2-bin \
+software-properties-common sqlite3 sudo unixodbc unixodbc-dev"
+ export DEV_APT_DEPS
+ fi
+}
+
+function get_runtime_apt_deps() {
+ if [[ "${RUNTIME_APT_DEPS=}" == "" ]]; then
+ RUNTIME_APT_DEPS="apt-transport-https apt-utils ca-certificates \
+curl dumb-init freetds-bin gosu krb5-user \
+ldap-utils libffi7 libldap-2.4-2 libsasl2-2 libsasl2-modules libssl1.1 locales \
+lsb-release netcat openssh-client python3-selinux rsync sasl2-bin sqlite3 sudo unixodbc"
+ export RUNTIME_APT_DEPS
+ fi
+}
+
+function install_docker_cli() {
+ local platform
+ if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then
+ platform="aarch64"
else
- echo
- echo "${color_red}Unknown distro version ${debian_version}${color_reset}"
- echo
- exit 1
+ platform="x86_64"
fi
+ curl --silent \
+ "https://download.docker.com/linux/static/stable/${platform}/docker-${DOCKER_CLI_VERSION}.tgz" \
+ | tar -C /usr/bin --strip-components=1 -xvzf - docker/docker
}
-determine_debian_version_specific_variables
+function install_debian_dev_dependencies() {
+ apt-get update
+ apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1
+ apt-get install -y --no-install-recommends curl gnupg2 lsb-release
+ # shellcheck disable=SC2086
+ export ${ADDITIONAL_DEV_APT_ENV?}
+ if [[ ${DEV_APT_COMMAND} != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}"
+ fi
+ if [[ ${ADDITIONAL_DEV_APT_COMMAND} != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}"
+ fi
+ apt-get update
+ # shellcheck disable=SC2086
+ apt-get install -y --no-install-recommends ${DEV_APT_DEPS} ${ADDITIONAL_DEV_APT_DEPS}
+}
+
+function install_debian_runtime_dependencies() {
+ apt-get update
+ apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1
+ apt-get install -y --no-install-recommends curl gnupg2 lsb-release
+ # shellcheck disable=SC2086
+ export ${ADDITIONAL_RUNTIME_APT_ENV?}
+ if [[ "${RUNTIME_APT_COMMAND}" != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}"
+ fi
+ if [[ "${ADDITIONAL_RUNTIME_APT_COMMAND}" != "" ]]; then
+ bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}"
+ fi
+ apt-get update
+ # shellcheck disable=SC2086
+ apt-get install -y --no-install-recommends ${RUNTIME_APT_DEPS} ${ADDITIONAL_RUNTIME_APT_DEPS}
+ apt-get autoremove -yqq --purge
+ apt-get clean
+ rm -rf /var/lib/apt/lists/* /var/log/*
+}
+
+if [[ "${INSTALLATION_TYPE}" == "RUNTIME" ]]; then
+ get_runtime_apt_deps
+ install_debian_runtime_dependencies
+ install_docker_cli
+
+else
+ get_dev_apt_deps
+ install_debian_dev_dependencies
+ install_docker_cli
+fi
EOF
# The content below is automatically copied from scripts/docker/install_mysql.sh
@@ -138,8 +205,6 @@ set -euo pipefail
COLOR_BLUE=$'\e[34m'
readonly COLOR_BLUE
-COLOR_YELLOW=$'\e[33m'
-readonly COLOR_YELLOW
COLOR_RESET=$'\e[0m'
readonly COLOR_RESET
@@ -157,19 +222,8 @@ function install_mssql_client() {
local distro
local version
distro=$(lsb_release -is | tr '[:upper:]' '[:lower:]')
- version_name=$(lsb_release -cs | tr '[:upper:]' '[:lower:]')
version=$(lsb_release -rs)
- local driver
- if [[ ${version_name} == "buster" ]]; then
- driver=msodbcsql17
- elif [[ ${version_name} == "bullseye" ]]; then
- driver=msodbcsql18
- else
- echo
- echo "${COLOR_YELLOW}Only Buster or Bullseye are supported. Skipping MSSQL installation${COLOR_RESET}"
- echo
- return
- fi
+ local driver=msodbcsql18
curl --silent https://packages.microsoft.com/keys/microsoft.asc | apt-key add - >/dev/null 2>&1
curl --silent "https://packages.microsoft.com/config/${distro}/${version}/prod.list" > \
/etc/apt/sources.list.d/mssql-release.list
@@ -181,11 +235,6 @@ function install_mssql_client() {
apt-get clean && rm -rf /var/lib/apt/lists/*
}
-if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then
- # disable MSSQL for ARM64
- INSTALL_MSSQL_CLIENT="false"
-fi
-
install_mssql_client "${@}"
EOF
@@ -236,20 +285,12 @@ COPY <<"EOF" /install_pip_version.sh
: "${AIRFLOW_PIP_VERSION:?Should be set}"
-function install_pip_version() {
- echo
- echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}"
- echo
- pip install --disable-pip-version-check --no-cache-dir --upgrade "pip==${AIRFLOW_PIP_VERSION}" &&
- mkdir -p ${HOME}/.local/bin
-}
-
common::get_colors
common::get_airflow_version_specification
common::override_pip_version_if_needed
common::show_pip_version_and_location
-install_pip_version
+common::install_pip_version
EOF
# The content below is automatically copied from scripts/docker/install_airflow_dependencies_from_branch_tip.sh
@@ -277,10 +318,10 @@ function install_airflow_dependencies_from_branch_tip() {
# are conflicts, this might fail, but it should be fixed in the following installation steps
set -x
pip install --root-user-action ignore \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"https://github.com/${AIRFLOW_REPO}/archive/${AIRFLOW_BRANCH}.tar.gz#egg=apache-airflow[${AIRFLOW_EXTRAS}]" \
--constraint "${AIRFLOW_CONSTRAINTS_LOCATION}" || true
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
pip freeze | grep apache-airflow-providers | xargs pip uninstall --yes 2>/dev/null || true
set +x
echo
@@ -327,7 +368,7 @@ function common::get_airflow_version_specification() {
function common::override_pip_version_if_needed() {
if [[ -n ${AIRFLOW_VERSION} ]]; then
if [[ ${AIRFLOW_VERSION} =~ ^2\.0.* || ${AIRFLOW_VERSION} =~ ^1\.* ]]; then
- export AIRFLOW_PIP_VERSION="22.1.2"
+ export AIRFLOW_PIP_VERSION="22.3.1"
fi
fi
}
@@ -355,6 +396,18 @@ function common::show_pip_version_and_location() {
echo "pip on path: $(which pip)"
echo "Using pip: $(pip --version)"
}
+
+function common::install_pip_version() {
+ echo
+ echo "${COLOR_BLUE}Installing pip version ${AIRFLOW_PIP_VERSION}${COLOR_RESET}"
+ echo
+ if [[ ${AIRFLOW_PIP_VERSION} =~ .*https.* ]]; then
+ pip install --disable-pip-version-check --no-cache-dir "pip @ ${AIRFLOW_PIP_VERSION}"
+ else
+ pip install --disable-pip-version-check --no-cache-dir "pip==${AIRFLOW_PIP_VERSION}"
+ fi
+ mkdir -p "${HOME}/.local/bin"
+}
EOF
# The content below is automatically copied from scripts/docker/install_pipx_tools.sh
@@ -384,101 +437,6 @@ common::get_colors
install_pipx_tools
EOF
-# The content below is automatically copied from scripts/docker/prepare_node_modules.sh
-COPY <<"EOF" /prepare_node_modules.sh
-set -euo pipefail
-
-COLOR_BLUE=$'\e[34m'
-readonly COLOR_BLUE
-COLOR_RESET=$'\e[0m'
-readonly COLOR_RESET
-
-function prepare_node_modules() {
- echo
- echo "${COLOR_BLUE}Preparing node modules${COLOR_RESET}"
- echo
- local www_dir
- if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then
- # In case we are building from sources in production image, we should build the assets
- www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www"
- else
- www_dir="$(python -m site --user-site)/airflow/www"
- fi
- pushd ${www_dir} || exit 1
- set +e
- yarn install --frozen-lockfile --no-cache 2>/tmp/out-yarn-install.txt
- local res=$?
- if [[ ${res} != 0 ]]; then
- >&2 echo
- >&2 echo "Error when running yarn install:"
- >&2 echo
- >&2 cat /tmp/out-yarn-install.txt && rm -f /tmp/out-yarn-install.txt
- exit 1
- fi
- rm -f /tmp/out-yarn-install.txt
- popd || exit 1
-}
-
-prepare_node_modules
-EOF
-
-# The content below is automatically copied from scripts/docker/compile_www_assets.sh
-COPY <<"EOF" /compile_www_assets.sh
-set -euo pipefail
-
-BUILD_TYPE=${BUILD_TYPE="prod"}
-REMOVE_ARTIFACTS=${REMOVE_ARTIFACTS="true"}
-
-COLOR_BLUE=$'\e[34m'
-readonly COLOR_BLUE
-COLOR_RESET=$'\e[0m'
-readonly COLOR_RESET
-
-function compile_www_assets() {
- echo
- echo "${COLOR_BLUE}Compiling www assets: running yarn ${BUILD_TYPE}${COLOR_RESET}"
- echo
- local www_dir
- if [[ ${AIRFLOW_INSTALLATION_METHOD=} == "." ]]; then
- # In case we are building from sources in production image, we should build the assets
- www_dir="${AIRFLOW_SOURCES_TO=${AIRFLOW_SOURCES}}/airflow/www"
- else
- www_dir="$(python -m site --user-site)/airflow/www"
- fi
- pushd ${www_dir} || exit 1
- set +e
- yarn run "${BUILD_TYPE}" 2>/tmp/out-yarn-run.txt
- res=$?
- if [[ ${res} != 0 ]]; then
- >&2 echo
- >&2 echo "Error when running yarn run:"
- >&2 echo
- >&2 cat /tmp/out-yarn-run.txt && rm -rf /tmp/out-yarn-run.txt
- exit 1
- fi
- rm -f /tmp/out-yarn-run.txt
- set -e
- local md5sum_file
- md5sum_file="static/dist/sum.md5"
- readonly md5sum_file
- find package.json yarn.lock static/css static/js -type f | sort | xargs md5sum > "${md5sum_file}"
- if [[ ${REMOVE_ARTIFACTS} == "true" ]]; then
- echo
- echo "${COLOR_BLUE}Removing generated node modules${COLOR_RESET}"
- echo
- rm -rf "${www_dir}/node_modules"
- rm -vf "${www_dir}"/{package.json,yarn.lock,.eslintignore,.eslintrc,.stylelintignore,.stylelintrc,compile_assets.sh,webpack.config.js}
- else
- echo
- echo "${COLOR_BLUE}Leaving generated node modules${COLOR_RESET}"
- echo
- fi
- popd || exit 1
-}
-
-compile_www_assets
-EOF
-
# The content below is automatically copied from scripts/docker/install_airflow.sh
COPY <<"EOF" /install_airflow.sh
@@ -511,6 +469,7 @@ function install_airflow() {
echo
# eager upgrade
pip install --root-user-action ignore --upgrade --upgrade-strategy eager \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \
${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS}
if [[ -n "${AIRFLOW_INSTALL_EDITABLE_FLAG}" ]]; then
@@ -519,12 +478,12 @@ function install_airflow() {
set -x
pip uninstall apache-airflow --yes
pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}"
set +x
fi
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
echo
@@ -535,16 +494,16 @@ function install_airflow() {
echo
set -x
pip install --root-user-action ignore ${AIRFLOW_INSTALL_EDITABLE_FLAG} \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}" \
--constraint "${AIRFLOW_CONSTRAINTS_LOCATION}"
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
# then upgrade if needed without using constraints to account for new limits in setup.py
pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${AIRFLOW_INSTALL_EDITABLE_FLAG} \
"${AIRFLOW_INSTALLATION_METHOD}[${AIRFLOW_EXTRAS}]${AIRFLOW_VERSION_SPECIFICATION}"
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
set +x
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
@@ -581,9 +540,9 @@ function install_additional_dependencies() {
echo
set -x
pip install --root-user-action ignore --upgrade --upgrade-strategy eager \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${ADDITIONAL_PYTHON_DEPS} ${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS}
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
set +x
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
@@ -595,9 +554,9 @@ function install_additional_dependencies() {
echo
set -x
pip install --root-user-action ignore --upgrade --upgrade-strategy only-if-needed \
+ ${ADDITIONAL_PIP_INSTALL_FLAGS} \
${ADDITIONAL_PYTHON_DEPS}
- # make sure correct PIP version is used
- pip install --disable-pip-version-check "pip==${AIRFLOW_PIP_VERSION}" 2>/dev/null
+ common::install_pip_version
set +x
echo
echo "${COLOR_BLUE}Running 'pip check'${COLOR_RESET}"
@@ -622,7 +581,7 @@ if [[ ${VERBOSE_COMMANDS:="false"} == "true" ]]; then
set -x
fi
-. /opt/airflow/scripts/in_container/_in_container_script_init.sh
+. "${AIRFLOW_SOURCES:-/opt/airflow}"/scripts/in_container/_in_container_script_init.sh
LD_PRELOAD="/usr/lib/$(uname -m)-linux-gnu/libstdc++.so.6"
export LD_PRELOAD
@@ -637,6 +596,35 @@ export AIRFLOW_HOME=${AIRFLOW_HOME:=${HOME}}
: "${AIRFLOW_SOURCES:?"ERROR: AIRFLOW_SOURCES not set !!!!"}"
+function wait_for_asset_compilation() {
+ if [[ -f "${AIRFLOW_SOURCES}/.build/www/.asset_compile.lock" ]]; then
+ echo
+ echo "${COLOR_YELLOW}Waiting for asset compilation to complete in the background.${COLOR_RESET}"
+ echo
+ local counter=0
+ while [[ -f "${AIRFLOW_SOURCES}/.build/www/.asset_compile.lock" ]]; do
+ echo "${COLOR_BLUE}Still waiting .....${COLOR_RESET}"
+ sleep 1
+ ((counter=counter+1))
+ if [[ ${counter} == "30" ]]; then
+ echo
+ echo "${COLOR_YELLOW}The asset compilation is taking too long.${COLOR_RESET}"
+ echo """
+If it does not complete soon, you might want to stop it and remove file lock:
+ * press Ctrl-C
+ * run 'rm ${AIRFLOW_SOURCES}/.build/www/.asset_compile.lock'
+"""
+ fi
+ if [[ ${counter} == "60" ]]; then
+ echo
+ echo "${COLOR_RED}The asset compilation is taking too long. Exiting.${COLOR_RESET}"
+ echo
+ exit 1
+ fi
+ done
+ fi
+}
+
if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
if [[ $(uname -m) == "arm64" || $(uname -m) == "aarch64" ]]; then
@@ -657,17 +645,13 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
RUN_TESTS=${RUN_TESTS:="false"}
CI=${CI:="false"}
USE_AIRFLOW_VERSION="${USE_AIRFLOW_VERSION:=""}"
+ TEST_TIMEOUT=${TEST_TIMEOUT:="60"}
if [[ ${USE_AIRFLOW_VERSION} == "" ]]; then
export PYTHONPATH=${AIRFLOW_SOURCES}
echo
echo "${COLOR_BLUE}Using airflow version from current sources${COLOR_RESET}"
echo
- if [[ -d "${AIRFLOW_SOURCES}/airflow/www/" ]]; then
- pushd "${AIRFLOW_SOURCES}/airflow/www/" >/dev/null
- ./ask_for_recompile_assets_if_needed.sh
- popd >/dev/null
- fi
# Cleanup the logs, tmp when entering the environment
sudo rm -rf "${AIRFLOW_SOURCES}"/logs/*
sudo rm -rf "${AIRFLOW_SOURCES}"/tmp/*
@@ -686,9 +670,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
echo "${COLOR_BLUE}Uninstalling airflow and providers${COLOR_RESET}"
echo
uninstall_airflow_and_providers
- echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
- echo
- install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
+ if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
+ echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
+ echo
+ install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none"
+ else
+ echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
+ echo
+ install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
+ fi
uninstall_providers
elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then
echo
@@ -696,9 +686,15 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
echo
uninstall_airflow_and_providers
echo
- echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
- echo
- install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
+ if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
+ echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
+ echo
+ install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none"
+ else
+ echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
+ echo
+ install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
+ fi
uninstall_providers
else
echo
@@ -706,9 +702,19 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
echo
uninstall_airflow_and_providers
echo
- echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
- echo
- install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
+ if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
+ echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
+ echo
+ install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none"
+ else
+ echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
+ echo
+ install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
+ fi
+ if [[ "${USE_AIRFLOW_VERSION}" =~ ^2\.2\..*|^2\.1\..*|^2\.0\..* && "${AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=}" != "" ]]; then
+ # make sure old variable is used for older airflow versions
+ export AIRFLOW__CORE__SQL_ALCHEMY_CONN="${AIRFLOW__DATABASE__SQL_ALCHEMY_CONN}"
+ fi
fi
if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then
echo
@@ -768,11 +774,12 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
echo
exit ${ENVIRONMENT_EXIT_CODE}
fi
- # Create symbolic link to fix possible issues with kubectl config cmd-path
mkdir -p /usr/lib/google-cloud-sdk/bin
touch /usr/lib/google-cloud-sdk/bin/gcloud
ln -s -f /usr/bin/gcloud /usr/lib/google-cloud-sdk/bin/gcloud
+ in_container_fix_ownership
+
if [[ ${SKIP_SSH_SETUP="false"} == "false" ]]; then
# Set up ssh keys
echo 'yes' | ssh-keygen -t rsa -C your_email@youremail.com -m PEM -P '' -f ~/.ssh/id_rsa \
@@ -803,8 +810,9 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
cd "${AIRFLOW_SOURCES}"
if [[ ${START_AIRFLOW:="false"} == "true" || ${START_AIRFLOW} == "True" ]]; then
- export AIRFLOW__CORE__LOAD_DEFAULT_CONNECTIONS=${LOAD_DEFAULT_CONNECTIONS}
+ export AIRFLOW__DATABASE__LOAD_DEFAULT_CONNECTIONS=${LOAD_DEFAULT_CONNECTIONS}
export AIRFLOW__CORE__LOAD_EXAMPLES=${LOAD_EXAMPLES}
+ wait_for_asset_compilation
# shellcheck source=scripts/in_container/bin/run_tmux
exec run_tmux
fi
@@ -816,7 +824,8 @@ if [[ "${RUN_TESTS}" != "true" ]]; then
fi
set -u
-export RESULT_LOG_FILE="/files/test_result-${TEST_TYPE}-${BACKEND}.xml"
+export RESULT_LOG_FILE="/files/test_result-${TEST_TYPE/\[*\]/}-${BACKEND}.xml"
+export WARNINGS_FILE="/files/warnings-${TEST_TYPE/\[*\]/}-${BACKEND}.txt"
EXTRA_PYTEST_ARGS=(
"--verbosity=0"
@@ -828,9 +837,11 @@ EXTRA_PYTEST_ARGS=(
# timeouts in seconds for individual tests
"--timeouts-order"
"moi"
- "--setup-timeout=60"
- "--execution-timeout=60"
- "--teardown-timeout=60"
+ "--setup-timeout=${TEST_TIMEOUT}"
+ "--execution-timeout=${TEST_TIMEOUT}"
+ "--teardown-timeout=${TEST_TIMEOUT}"
+ "--output=${WARNINGS_FILE}"
+ "--disable-warnings"
# Only display summary for non-expected case
# f - failed
# E - error
@@ -844,9 +855,11 @@ EXTRA_PYTEST_ARGS=(
)
if [[ "${TEST_TYPE}" == "Helm" ]]; then
+ _cpus="$(grep -c 'cpu[0-9]' /proc/stat)"
+ echo "Running tests with ${_cpus} CPUs in parallel"
# Enable parallelism
EXTRA_PYTEST_ARGS+=(
- "-n" "auto"
+ "-n" "${_cpus}"
)
else
EXTRA_PYTEST_ARGS+=(
@@ -858,7 +871,7 @@ if [[ ${ENABLE_TEST_COVERAGE:="false"} == "true" ]]; then
EXTRA_PYTEST_ARGS+=(
"--cov=airflow/"
"--cov-config=.coveragerc"
- "--cov-report=xml:/files/coverage-${TEST_TYPE}-${BACKEND}.xml"
+ "--cov-report=xml:/files/coverage-${TEST_TYPE/\[*\]/}-${BACKEND}.xml"
)
fi
@@ -900,11 +913,13 @@ else
)
WWW_TESTS=("tests/www")
HELM_CHART_TESTS=("tests/charts")
+ INTEGRATION_TESTS=("tests/integration")
ALL_TESTS=("tests")
ALL_PRESELECTED_TESTS=(
"${CLI_TESTS[@]}"
"${API_TESTS[@]}"
"${HELM_CHART_TESTS[@]}"
+ "${INTEGRATION_TESTS[@]}"
"${PROVIDERS_TESTS[@]}"
"${CORE_TESTS[@]}"
"${ALWAYS_TESTS[@]}"
@@ -925,33 +940,38 @@ else
SELECTED_TESTS=("${WWW_TESTS[@]}")
elif [[ ${TEST_TYPE:=""} == "Helm" ]]; then
SELECTED_TESTS=("${HELM_CHART_TESTS[@]}")
+ elif [[ ${TEST_TYPE:=""} == "Integration" ]]; then
+ SELECTED_TESTS=("${INTEGRATION_TESTS[@]}")
elif [[ ${TEST_TYPE:=""} == "Other" ]]; then
find_all_other_tests
SELECTED_TESTS=("${ALL_OTHER_TESTS[@]}")
elif [[ ${TEST_TYPE:=""} == "All" || ${TEST_TYPE} == "Quarantined" || \
${TEST_TYPE} == "Always" || \
${TEST_TYPE} == "Postgres" || ${TEST_TYPE} == "MySQL" || \
- ${TEST_TYPE} == "Long" || \
- ${TEST_TYPE} == "Integration" ]]; then
+ ${TEST_TYPE} == "Long" ]]; then
SELECTED_TESTS=("${ALL_TESTS[@]}")
+ elif [[ ${TEST_TYPE} =~ Providers\[(.*)\] ]]; then
+ SELECTED_TESTS=()
+ for provider in ${BASH_REMATCH[1]//,/ }
+ do
+ providers_dir="tests/providers/${provider//./\/}"
+ if [[ -d ${providers_dir} ]]; then
+ SELECTED_TESTS+=("${providers_dir}")
+ else
+ echo "${COLOR_YELLOW}Skip ${providers_dir} as the directory does not exist.${COLOR_RESET}"
+ fi
+ done
else
echo
echo "${COLOR_RED}ERROR: Wrong test type ${TEST_TYPE} ${COLOR_RESET}"
echo
exit 1
fi
-
fi
readonly SELECTED_TESTS CLI_TESTS API_TESTS PROVIDERS_TESTS CORE_TESTS WWW_TESTS \
ALL_TESTS ALL_PRESELECTED_TESTS
-if [[ -n ${LIST_OF_INTEGRATION_TESTS_TO_RUN=} ]]; then
- # Integration tests
- for INT in ${LIST_OF_INTEGRATION_TESTS_TO_RUN}
- do
- EXTRA_PYTEST_ARGS+=("--integration" "${INT}")
- done
-elif [[ ${TEST_TYPE:=""} == "Long" ]]; then
+if [[ ${TEST_TYPE:=""} == "Long" ]]; then
EXTRA_PYTEST_ARGS+=(
"-m" "long_running"
"--include-long-running"
@@ -998,16 +1018,6 @@ COPY <<"EOF" /entrypoint_exec.sh
exec /bin/bash "${@}"
EOF
-##############################################################################################
-# This is the www image where we keep all inlined files needed to build ui
-# It is copied separately to volume to speed up building and avoid cache miss on changed
-# file permissions.
-# We use PYTHON_BASE_IMAGE to make sure that the scripts are different for different platforms.
-##############################################################################################
-FROM ${PYTHON_BASE_IMAGE} as www
-COPY airflow/www/package.json airflow/www/yarn.lock airflow/www/webpack.config.js /
-COPY airflow/www/static/ /static
-
FROM ${PYTHON_BASE_IMAGE} as main
# Nolog bash flag is currently ignored - but you can replace it with other flags (for example
@@ -1018,7 +1028,7 @@ ARG PYTHON_BASE_IMAGE
ARG AIRFLOW_IMAGE_REPOSITORY="https://github.com/apache/airflow"
# By increasing this number we can do force build of all dependencies
-ARG DEPENDENCIES_EPOCH_NUMBER="6"
+ARG DEPENDENCIES_EPOCH_NUMBER="7"
# Make sure noninteractive debian install is used and language variables set
ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \
@@ -1031,68 +1041,35 @@ ENV PYTHON_BASE_IMAGE=${PYTHON_BASE_IMAGE} \
RUN echo "Base image version: ${PYTHON_BASE_IMAGE}"
-ARG ADDITIONAL_DEV_APT_DEPS=""
-ARG DEV_APT_COMMAND="\
- curl --silent --fail --location https://deb.nodesource.com/setup_14.x | bash - \
- && curl --silent --fail https://dl.yarnpkg.com/debian/pubkey.gpg | apt-key add - >/dev/null 2>&1 \
- && echo 'deb https://dl.yarnpkg.com/debian/ stable main' > /etc/apt/sources.list.d/yarn.list"
+ARG ADDITIONAL_DEV_APT_DEPS="git graphviz gosu libpq-dev netcat rsync"
+ARG DEV_APT_COMMAND=""
ARG ADDITIONAL_DEV_APT_COMMAND=""
ARG ADDITIONAL_DEV_ENV_VARS=""
+ARG ADDITIONAL_DEV_APT_DEPS="bash-completion dumb-init git graphviz gosu krb5-user \
+less libenchant-2-2 libgcc-10-dev libpq-dev net-tools netcat \
+openssh-server postgresql-client software-properties-common rsync tmux unzip vim xxd"
+
+ARG ADDITIONAL_DEV_APT_ENV=""
ENV DEV_APT_COMMAND=${DEV_APT_COMMAND} \
ADDITIONAL_DEV_APT_DEPS=${ADDITIONAL_DEV_APT_DEPS} \
ADDITIONAL_DEV_APT_COMMAND=${ADDITIONAL_DEV_APT_COMMAND}
-COPY --from=scripts determine_debian_version_specific_variables.sh /scripts/docker/
-
-# Install basic and additional apt dependencies
-RUN apt-get update \
- && apt-get install --no-install-recommends -yqq apt-utils >/dev/null 2>&1 \
- && apt-get install -y --no-install-recommends curl gnupg2 lsb-release \
- && mkdir -pv /usr/share/man/man1 \
- && mkdir -pv /usr/share/man/man7 \
- && export ${ADDITIONAL_DEV_ENV_VARS?} \
- && source /scripts/docker/determine_debian_version_specific_variables.sh \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${DEV_APT_COMMAND}" \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_DEV_APT_COMMAND}" \
- && apt-get update \
- && apt-get install -y --no-install-recommends \
- apt-utils \
- build-essential \
- dirmngr \
- dumb-init \
- freetds-bin \
- freetds-dev \
- git \
- graphviz \
- gosu \
- libffi-dev \
- libldap2-dev \
- libkrb5-dev \
- libpq-dev \
- libsasl2-2 \
- libsasl2-dev \
- libsasl2-modules \
- libssl-dev \
- "${DISTRO_LIBENCHANT}" \
- locales \
- netcat \
- nodejs \
- rsync \
- sasl2-bin \
- sudo \
- unixodbc \
- unixodbc-dev \
- yarn \
- ${ADDITIONAL_DEV_APT_DEPS} \
- && apt-get autoremove -yqq --purge \
- && apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+COPY --from=scripts install_os_dependencies.sh /scripts/docker/
+RUN bash /scripts/docker/install_os_dependencies.sh dev
# Only copy mysql/mssql installation scripts for now - so that changing the other
# scripts which are needed much later will not invalidate the docker layer here.
COPY --from=scripts install_mysql.sh install_mssql.sh install_postgres.sh /scripts/docker/
+ARG HOME=/root
+ARG AIRFLOW_HOME=/root/airflow
+ARG AIRFLOW_SOURCES=/opt/airflow
+
+ENV HOME=${HOME} \
+ AIRFLOW_HOME=${AIRFLOW_HOME} \
+ AIRFLOW_SOURCES=${AIRFLOW_SOURCES}
+
# We run scripts with bash here to make sure we can execute the scripts. Changing to +x might have an
# unexpected result - the cache for Dockerfiles might get invalidated in case the host system
# had different umask set and group x bit was not set. In Azure the bit might be not set at all.
@@ -1107,28 +1084,8 @@ RUN bash /scripts/docker/install_mysql.sh prod \
&& echo "airflow ALL=(ALL) NOPASSWD: ALL" > /etc/sudoers.d/airflow \
&& chmod 0440 /etc/sudoers.d/airflow
-ARG RUNTIME_APT_DEPS="\
- apt-transport-https \
- bash-completion \
- ca-certificates \
- software-properties-common \
- krb5-user \
- krb5-user \
- ldap-utils \
- less \
- lsb-release \
- net-tools \
- openssh-client \
- openssh-server \
- postgresql-client \
- sqlite3 \
- tmux \
- unzip \
- vim \
- xxd"
-
# Install Helm
-ARG HELM_VERSION="v3.6.3"
+ARG HELM_VERSION="v3.9.4"
RUN SYSTEM=$(uname -s | tr '[:upper:]' '[:lower:]') \
&& PLATFORM=$([ "$(uname -m)" = "aarch64" ] && echo "arm64" || echo "amd64" ) \
@@ -1136,42 +1093,6 @@ RUN SYSTEM=$(uname -s | tr '[:upper:]' '[:lower:]') \
&& curl --silent --location "${HELM_URL}" | tar -xz -O "${SYSTEM}-${PLATFORM}/helm" > /usr/local/bin/helm \
&& chmod +x /usr/local/bin/helm
-ARG ADDITIONAL_RUNTIME_APT_DEPS=""
-ARG RUNTIME_APT_COMMAND=""
-ARG ADDITIONAL_RUNTIME_APT_COMMAND=""
-ARG ADDITIONAL_DEV_APT_ENV=""
-ARG ADDITIONAL_RUNTIME_APT_ENV=""
-
-ARG DOCKER_CLI_VERSION=19.03.9
-ARG HOME=/root
-ARG AIRFLOW_HOME=/root/airflow
-ARG AIRFLOW_SOURCES=/opt/airflow
-
-ENV RUNTIME_APT_DEP=${RUNTIME_APT_DEPS} \
- ADDITIONAL_RUNTIME_APT_DEPS=${ADDITIONAL_RUNTIME_APT_DEPS} \
- RUNTIME_APT_COMMAND=${RUNTIME_APT_COMMAND} \
- ADDITIONAL_RUNTIME_APT_COMMAND=${ADDITIONAL_RUNTIME_APT_COMMAND}\
- DOCKER_CLI_VERSION=${DOCKER_CLI_VERSION} \
- HOME=${HOME} \
- AIRFLOW_HOME=${AIRFLOW_HOME} \
- AIRFLOW_SOURCES=${AIRFLOW_SOURCES}
-
-RUN export ${ADDITIONAL_DEV_APT_ENV?} \
- && export ${ADDITIONAL_RUNTIME_APT_ENV?} \
- && source /scripts/docker/determine_debian_version_specific_variables.sh \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${RUNTIME_APT_COMMAND}" \
- && bash -o pipefail -o errexit -o nounset -o nolog -c "${ADDITIONAL_RUNTIME_APT_COMMAND}" \
- && apt-get update \
- && apt-get install --no-install-recommends -y \
- "${DISTRO_LIBGCC}" \
- ${RUNTIME_APT_DEPS} \
- ${ADDITIONAL_RUNTIME_APT_DEPS} \
- && apt-get autoremove -yqq --purge \
- && apt-get clean \
- && rm -rf /var/lib/apt/lists/* \
- && curl --silent "https://download.docker.com/linux/static/stable/x86_64/docker-${DOCKER_CLI_VERSION}.tgz" \
- | tar -C /usr/bin --strip-components=1 -xvzf - docker/docker
-
WORKDIR ${AIRFLOW_SOURCES}
RUN mkdir -pv ${AIRFLOW_HOME} && \
@@ -1195,7 +1116,7 @@ ARG AIRFLOW_CI_BUILD_EPOCH="3"
ARG AIRFLOW_PRE_CACHED_PIP_PACKAGES="true"
# By default in the image, we are installing all providers when installing from sources
ARG INSTALL_PROVIDERS_FROM_SOURCES="true"
-ARG AIRFLOW_PIP_VERSION=22.1.2
+ARG AIRFLOW_PIP_VERSION=22.3.1
# Setup PIP
# By default PIP install run without cache to make image smaller
ARG PIP_NO_CACHE_DIR="true"
@@ -1208,7 +1129,10 @@ ARG CASS_DRIVER_NO_CYTHON="1"
# Build cassandra driver on multiple CPUs
ARG CASS_DRIVER_BUILD_CONCURRENCY="8"
-ARG AIRFLOW_VERSION="2.3.0.dev"
+ARG AIRFLOW_VERSION="2.5.0.dev0"
+
+# Additional PIP flags passed to all pip install commands except reinstalling pip itself
+ARG ADDITIONAL_PIP_INSTALL_FLAGS=""
ENV AIRFLOW_REPO=${AIRFLOW_REPO}\
AIRFLOW_BRANCH=${AIRFLOW_BRANCH} \
@@ -1237,6 +1161,7 @@ ENV AIRFLOW_REPO=${AIRFLOW_REPO}\
AIRFLOW_VERSION_SPECIFICATION="" \
PIP_NO_CACHE_DIR=${PIP_NO_CACHE_DIR} \
PIP_PROGRESS_BAR=${PIP_PROGRESS_BAR} \
+ ADDITIONAL_PIP_INSTALL_FLAGS=${ADDITIONAL_PIP_INSTALL_FLAGS} \
CASS_DRIVER_BUILD_CONCURRENCY=${CASS_DRIVER_BUILD_CONCURRENCY} \
CASS_DRIVER_NO_CYTHON=${CASS_DRIVER_NO_CYTHON}
@@ -1245,7 +1170,16 @@ RUN echo "Airflow version: ${AIRFLOW_VERSION}"
# Those are additional constraints that are needed for some extras but we do not want to
# force them on the main Airflow package. Those limitations are:
# * dill<0.3.3 required by apache-beam
-ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3"
+# * pyarrow>=6.0.0 is because pip resolver decides for Python 3.10 to downgrade pyarrow to 5 even if it is OK
+# for python 3.10 and other dependencies adding the limit helps resolver to make better decisions
+# We need to limit the protobuf library to < 4.21.0 because not all google libraries we use
+# are compatible with the new protobuf version. All the google python client libraries need
+# to be upgraded to >= 2.0.0 in order to able to lift that limitation
+# https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
+# * authlib, gcloud_aio_auth, adal are needed to generate constraints for PyPI packages and can be removed after we release
+# new google, azure providers
+# !!! MAKE SURE YOU SYNCHRONIZE THE LIST BETWEEN: Dockerfile, Dockerfile.ci, find_newer_dependencies.py
+ARG EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS="dill<0.3.3 pyarrow>=6.0.0 protobuf<4.21.0 authlib>=1.0.0 gcloud_aio_auth>=4.0.0 adal>=1.2.7"
ARG UPGRADE_TO_NEWER_DEPENDENCIES="false"
ENV EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS=${EAGER_UPGRADE_ADDITIONAL_REQUIREMENTS} \
UPGRADE_TO_NEWER_DEPENDENCIES=${UPGRADE_TO_NEWER_DEPENDENCIES}
@@ -1279,32 +1213,13 @@ COPY --from=scripts install_pipx_tools.sh /scripts/docker/
# dependencies installed in Airflow
RUN bash /scripts/docker/install_pipx_tools.sh
-# Copy package.json and yarn.lock to install node modules
-# this way even if other static check files change, node modules will not need to be installed
-# we want to keep node_modules so we can do this step separately from compiling assets
-COPY --from=www package.json yarn.lock ${AIRFLOW_SOURCES}/airflow/www/
-COPY --from=scripts prepare_node_modules.sh /scripts/docker/
-
-# Package JS/css for production
-RUN bash /scripts/docker/prepare_node_modules.sh
-
-# Copy all the needed www/ for assets compilation. Done as two separate COPY
-# commands so as otherwise it copies the _contents_ of static/ in to www/
-COPY --from=www webpack.config.js ${AIRFLOW_SOURCES}/airflow/www/
-COPY --from=www static ${AIRFLOW_SOURCES}/airflow/www/static/
-COPY --from=scripts compile_www_assets.sh /scripts/docker/
-
-# Build artifacts without removing temporary artifacts (we will need them for incremental changes)
-# in build mode
-RUN REMOVE_ARTIFACTS="false" BUILD_TYPE="build" bash /scripts/docker/compile_www_assets.sh
-
# Airflow sources change frequently but dependency configuration won't change that often
# We copy setup.py and other files needed to perform setup of dependencies
# So in case setup.py changes we can install latest dependencies required.
COPY setup.py ${AIRFLOW_SOURCES}/setup.py
COPY setup.cfg ${AIRFLOW_SOURCES}/setup.cfg
-
COPY airflow/__init__.py ${AIRFLOW_SOURCES}/airflow/
+COPY generated/provider_dependencies.json ${AIRFLOW_SOURCES}/generated/
COPY --from=scripts install_airflow.sh /scripts/docker/
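
The ``TEST_TYPE`` handling added to the entrypoint in the hunks above can be tried in isolation. The following is a minimal sketch, not the project's code: the helper names ``strip_selector`` and ``provider_test_dirs`` are hypothetical, and unlike the entrypoint the sketch does not check that the directories exist. It shows how a value such as ``Providers[amazon,apache.hive]`` is reduced to ``Providers`` for file names and expanded into one test directory per provider.

```bash
# Sketch only: the function names are hypothetical and are not part of the Airflow scripts.

# Remove a "[...]" selector from a test type. The entrypoint uses the same
# expansion when it builds the result, warnings and coverage file names.
strip_selector() {
    local test_type="${1}"
    echo "${test_type/\[*\]/}"
}

# Print one candidate test directory per provider listed in "Providers[a,b.c]".
# Assumption: the existence check done by the entrypoint is skipped here.
provider_test_dirs() {
    local test_type="${1}"
    local provider
    if [[ ${test_type} =~ Providers\[(.*)\] ]]; then
        # Commas become spaces so that word splitting yields one provider per iteration.
        for provider in ${BASH_REMATCH[1]//,/ }; do
            # Dots in the provider id become path separators (apache.hive -> apache/hive).
            echo "tests/providers/${provider//.//}"
        done
    fi
}

strip_selector "Providers[amazon,apache.hive]"
provider_test_dirs "Providers[amazon,apache.hive]"
```

A test type without a selector (for example ``Core``) passes through ``strip_selector`` unchanged and produces no output from ``provider_test_dirs``.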
diff --git a/IMAGES.rst b/IMAGES.rst
index 58ef0cca54852..f6b78aee0f310 100644
--- a/IMAGES.rst
+++ b/IMAGES.rst
@@ -85,13 +85,13 @@ You can build the CI image using current sources this command:
.. code-block:: bash
- breeze build-image
+ breeze ci-image build
You can build the PROD image using current sources with this command:
.. code-block:: bash
- breeze build-prod-image
+ breeze prod-image build
By adding ``--python `` parameter you can build the
image version for the chosen Python version.
@@ -104,13 +104,13 @@ For example if you want to build Python 3.7 version of production image with
.. code-block:: bash
- breeze build-prod-image --python 3.7 --extras "all"
+ breeze prod-image build --python 3.7 --extras "all"
If you just want to add new extras you can add them like that:
.. code-block:: bash
- breeze build-prod-image --python 3.7 --additional-extras "all"
+ breeze prod-image build --python 3.7 --additional-extras "all"
The command that builds the CI image is optimized to minimize the time needed to rebuild the image when
the source code of Airflow evolves. This means that if you already have the image locally downloaded and
@@ -128,7 +128,7 @@ parameter to Breeze:
.. code-block:: bash
- breeze build-prod-image --python 3.7 --additional-extras=trino --install-airflow-version=2.0.0
+ breeze prod-image build --python 3.7 --additional-extras=trino --install-airflow-version=2.0.0
This will build the image using command similar to:
@@ -165,8 +165,7 @@ You can also skip installing airflow and install it from locally provided files
.. code-block:: bash
- breeze build-prod-image --python 3.7 --additional-extras=trino \
- --airflow-is-in-context-pypi --install-packages-from-context
+ breeze prod-image build --python 3.7 --additional-extras=trino --install-packages-from-context
In this case airflow and all packages (.whl files) should be placed in the ``docker-context-files`` folder.
@@ -193,21 +192,21 @@ or ``disabled`` flags when you run Breeze commands. For example:
.. code-block:: bash
- breeze build-image --python 3.7 --docker-cache local
+ breeze ci-image build --python 3.7 --docker-cache local
Will build the CI image using local build cache (note that it will take quite a long time the first
time you run it).
.. code-block:: bash
- breeze build-prod-image --python 3.7 --docker-cache registry
+ breeze prod-image build --python 3.7 --docker-cache registry
Will build the production image with cache used from registry.
.. code-block:: bash
- breeze build-prod-image --python 3.7 --docker-cache disabled
+ breeze prod-image build --python 3.7 --docker-cache disabled
Will build the production image from scratch.
@@ -281,7 +280,7 @@ to refresh them.
Every developer can also pull and run images being result of a specific CI run in GitHub Actions.
This is a powerful tool that allows you to reproduce CI failures locally, enter the images and fix them much
-faster. It is enough to pass ``--github-image-id`` and the registry and Breeze will download and execute
+faster. It is enough to pass ``--image-tag`` and the registry and Breeze will download and execute
commands using the same image that was used during the CI tests.
For example this command will run the same Python 3.8 image as was used in build identified with
@@ -289,8 +288,7 @@ For example this command will run the same Python 3.8 image as was used in build
.. code-block:: bash
- ./breeze-legacy --github-image-id 9a621eaa394c0a0a336f8e1b31b35eff4e4ee86e \
- --python 3.8 --integration rabbitmq
+ breeze --image-tag 9a621eaa394c0a0a336f8e1b31b35eff4e4ee86e --python 3.8 --integration rabbitmq
You can see more details and examples in `Breeze `_
@@ -318,10 +316,9 @@ you have ``buildx`` plugin installed.
DOCKER_BUILDKIT=1 docker build . -f Dockerfile.ci \
--pull \
--build-arg PYTHON_BASE_IMAGE="python:3.7-slim-bullseye" \
- --build-arg ADDITIONAL_AIRFLOW_EXTRAS="jdbc"
- --build-arg ADDITIONAL_PYTHON_DEPS="pandas"
- --build-arg ADDITIONAL_DEV_APT_DEPS="gcc g++"
- --build-arg ADDITIONAL_RUNTIME_APT_DEPS="default-jre-headless"
+ --build-arg ADDITIONAL_AIRFLOW_EXTRAS="jdbc" \
+ --build-arg ADDITIONAL_PYTHON_DEPS="pandas" \
+ --build-arg ADDITIONAL_DEV_APT_DEPS="gcc g++" \
--tag my-image:0.0.1
@@ -329,8 +326,8 @@ the same image can be built using ``breeze`` (it supports auto-completion of the
.. code-block:: bash
- breeze build-prod-image --python 3.7 --additional-extras=jdbc --additional-python-deps="pandas" \
- --additional-dev-apt-deps="gcc g++" --additional-runtime-apt-deps="default-jre-headless"
+ breeze ci-image build --python 3.7 --additional-extras=jdbc --additional-python-deps="pandas" \
+ --additional-dev-apt-deps="gcc g++"
You can customize more aspects of the image - such as additional commands executed before apt dependencies
are installed, or adding extra sources to install your dependencies from. You can see all the arguments
@@ -355,10 +352,7 @@ based on example in `this comment `_ where they describe
+the issues they think are Airflow issues and should be solved. There are two kinds of issues:
+
+* Bugs - when the user thinks the reported issue is a bug in Airflow
+* Features - when there are small features that the user would like to see in Airflow
+
+We have `templates `_ for both types
+of issues defined in Airflow.
+
+However, an important part of our issue reporting process is
+`GitHub Discussions `_ . Issues should represent
+clear, small feature requests or reproducible bugs which can/should be either implemented or fixed.
+Users are encouraged to open discussions rather than issues if there are no clear, reproducible
+steps, or when they have troubleshooting problems. One of the important points of issue triaging is
+to determine whether the reported issue should rather be a discussion. Converting an issue to a discussion
+while explaining to the user why is an important part of the triaging process.
+
+Responding to issues/discussions (relatively) quickly
+'''''''''''''''''''''''''''''''''''''''''''''''''''''
+
+It is vital to provide rather quick feedback to issues and discussions opened by our users, so that they
+feel listened to rather than ignored. Even if the response is "we are not going to work on it because ...",
+or "converting this issue to discussion because ..." or "closing because it is a duplicate of #xxx", it is
+far more welcoming than leaving issues and discussions unanswered. Sometimes issues and discussions are
+answered by other users (and this is cool) but if an issue/discussion is not responded to for a few days or
+weeks, this gives an impression that the user was ignored and that the Airflow project is unwelcoming.
+
+We strive to provide relatively quick responses to all such issues and discussions. Users should exercise
+patience while waiting for those (knowing that people might be busy, on vacation, etc.); however, they should
+not wait weeks until someone looks at their issues.
+
+
+Issue Triage Team
+''''''''''''''''''
+
+While many of the issues can be responded to by other users and committers, the committer team is not
+big enough to handle all such requests and sometimes they are busy with implementing important features.
+Therefore, some people who are regularly contributing and helping other users and have shown their deep interest
+in the project can be invited to join the triage team.
+The triage team is defined in `the .asf.yaml <.asf.yaml>`_ file in the ``collaborators`` section.
+
+Committers can invite people to become members of the triage team if they see that the users are already
+helping and responding to issues and when they see the users are involved regularly. But you can also ask
+to become a member of the team (on devlist) if you can show that you have done that and when you want to have
+more ways to help others.
+
+The triage team members do not have committer privileges but they can
+assign, edit, and close issues and pull requests without having capabilities to merge the code. They can
+also convert issues into discussions and back. The expectation for the issue triage team is that they
+spend a bit of their time on those efforts. Triaging means not only assigning the labels but often responding
+to the issues and answering user concerns or, if additional input is needed, tagging the committers or
+other community members who might be able to help provide more complete answers.
+
+Being an active and helpful member of the "Issue Triage Team" is actually one of the paths towards
+becoming a committer. Actively helping the users, triaging the issues, responding to them and
+involving others (when needed) shows that you are not only willing to help our users and the community,
+but are also ready to learn about parts of the project you are not actively contributing to - all of which
+are valuable components of being eligible to `become a committer `_.
+
+If you are a member of the triage team and not able to make any commitment, it's best to ask to have yourself
+removed from the triage team.
+
+By the way, every committer is pretty much automatically part of the "Issue Triage Team", so if you are a committer,
+feel free to follow the process for every issue you stumble upon.
+
+Actions that can be taken by the issue triager
+''''''''''''''''''''''''''''''''''''''''''''''
+
+There are several actions an issue triager might take:
+
+* Closing an issue with the "invalid" label, explaining why it is closed, in case the issue is invalid. This
+ should be accompanied by information that we can always re-open an issue if our understanding was wrong
+ or if the user provides more information.
+
+* Converting an issue to a discussion, if it is not very likely it is an Airflow issue or when it is not
+ reproducible, or when it is a bigger feature proposal requiring discussion or when it's really user
+ troubleshooting or when the issue description is not at all clear. This also involves inviting the user
+ to a discussion if more information might change it.
+
+* Assigning the issue to a milestone, if the issue seems important enough that it should likely be looked
+ at before the next release but there is not enough information, or there are doubts about why and what can be fixed.
+ Usually we assign it to the next bugfix release; then, no matter what, the issue will be looked at
+ by the release manager and it might trigger additional actions during the release preparation.
+ This is usually followed by one of the actions below.
+
+* Fixing the issue in a PR if you see it is easy to fix. This is a great way also to learn and
+ contribute to parts that you usually are not contributing to, and sometimes it is surprisingly easy.
+
+* Assigning the "good first issue" label if an issue is clear but not important enough to be fixed immediately. This
+ often leads to contributors picking up the issues when they are interested. This can be followed by assigning
+ the issue to the user who comments "I want to work on this" in the issue (which is most welcome).
+
+* Asking the user for additional information if it is needed to perform further investigations. This should
+ be accompanied by assigning the ``pending-response`` label so that we can clearly see the issues that need
+ extra information.
+
+* Calling other people who might be knowledgeable in the area by @-mentioning them in a comment.
+
+* Assigning other labels to the issue as described below.
+
Labels
''''''
-Since Apache Airflow uses GitHub Issues as the issue tracking system, the
-use of labels is extensive. Though issue labels tend to change over time
-based on components within the project, the majority of the ones listed
-below should stand the test of time.
+Since Apache Airflow uses "GitHub Issues" and "GitHub Discussions" as the
+issue tracking systems, the use of labels is extensive. Though issue
+labels tend to change over time based on components within the project,
+the majority of the ones listed below should stand the test of time.
The intention with the use of labels with the Apache Airflow project is
that they should ideally be non-temporal in nature and primarily used
@@ -44,16 +140,16 @@ to indicate the following elements:
**Kind**
-The “kind” labels indicate “what kind of issue it is”. The most
-commonly used “kind” labels are: bug, feature, documentation, or task.
+The "kind" labels indicate "what kind of issue it is". The most
+commonly used "kind" labels are: bug, feature, documentation, or task.
Therefore, when reporting an issue, the label of ``kind:bug`` is to
indicate a problem with the functionality, whereas the label of
``kind:feature`` is a desire to extend the functionality.
There has been discussion within the project about whether to separate
-the desire for “new features” from “enhancements to existing features”,
-but in practice most “feature requests” are actually enhancement requests,
+the desire for "new features" from "enhancements to existing features",
+but in practice most "feature requests" are actually enhancement requests,
so we decided to combine them both into ``kind:feature``.
The ``kind:task`` is used to categorize issues which are
@@ -67,7 +163,7 @@ made to the documentation within the project.
**Area**
-The “area” set of labels should indicate the component of the code
+The "area" set of labels should indicate the component of the code
referenced by the issue. At a high level, the biggest areas of the project
are: Airflow Core and Airflow Providers, which are referenced by ``area:core``
and ``area:providers``. This is especially important since these are now
@@ -75,15 +171,18 @@ being released and versioned independently.
There are more detailed areas of the Core Airflow project such as Scheduler, Webserver,
API, UI, Logging, and Kubernetes, which are all conceptually under the
-“Airflow Core” area of the project.
+"Airflow Core" area of the project.
Similarly within Airflow Providers, the larger providers such as Apache, AWS, Azure,
and Google who have many hooks and operators within them, have labels directly
-associated with them such as ``provider/Apache``, ``provider/AWS``,
-``provider/Azure``, and ``provider/Google``.
+associated with them such as ``provider:Apache``, ``provider:AWS``,
+``provider:Azure``, and ``provider:Google``.
+
These make it easier for developers working on a single provider to
track issues for that provider.
+Some provider labels may group several providers, for example ``provider:Protocols``.
+
Most issues need a combination of "kind" and "area" labels to be actionable.
For example:
@@ -91,7 +190,6 @@ For example:
* Bug report on the User Interface would have ``kind:bug`` and ``area:UI``
* Documentation request on the Kubernetes Executor, would have ``kind:documentation`` and ``area:kubernetes``
-
Response to issues
''''''''''''''''''
@@ -116,7 +214,7 @@ Therefore, the priority labels used are:
It's important to use priority labels effectively so we can triage incoming issues
appropriately and make sure that when we release a new version of Airflow,
-we can ship a release confident that there are no “production blocker” issues in it.
+we can ship a release confident that there are no "production blocker" issues in it.
This applies to both Core Airflow as well as the Airflow Providers. With the separation
of the Providers release from Core Airflow, a ``priority:critical`` bug in a single
@@ -151,11 +249,7 @@ to be able to reproduce the issue. Typically, this may require a response to the
issue creator asking for more information, with the issue then being tagged with
the label ``pending-response``.
Also, during this stage, additional labels may be added to the issue to help
-classification and triage, such as ``reported_version`` and ``area``.
-
-Occasionally an issue may require a larger discussion among the Airflow PMC or
-the developer mailing list. This issue may then be tagged with the
-``needs:discussion`` label.
+classification and triage, such as ``affected_version`` and ``area``.
Some issues may need a detailed review by one of the core committers of the project
and this could be tagged with a ``needs:triage`` label.
@@ -164,7 +258,7 @@ and this could be tagged with a ``needs:triage`` label.
**Good First Issue**
Issues which are relatively straightforward to solve will be tagged with
-the ``GoodFirstIssue`` label.
+the ``good first issue`` label.
The intention here is to galvanize contributions from new and inexperienced
contributors who are looking to contribute to the project. This has been successful
@@ -175,13 +269,13 @@ Ideally, these issues only require one or two files to be changed. The intention
here is that incremental changes to existing files are a lot easier for a new
contributor as compared to adding something completely new.
-Another possibility here is to add “how to fix” in the comments of such issues, so
+Another possibility here is to add "how to fix" in the comments of such issues, so
that new contributors have a running start when they pick up these issues.
**Timeliness**
-For the sake of quick responses, the general “soft" rule within the Airflow project
+For the sake of quick responses, the general "soft" rule within the Airflow project
is that if there is no assignee, anyone can take an issue to solve.
However, this depends on timely resolution of the issue by the assignee. The
@@ -196,13 +290,20 @@ issue creator. After the pending-response label has been assigned, if there is n
further information for a period of 1 month, the issue will be automatically closed.
-
**Invalidity**
At times issues are marked as invalid and later closed because of one of the
following situations:
* The issue is a duplicate of an already reported issue. In such cases, the latter issue is marked as ``duplicate``.
-* Despite attempts to reproduce the issue to resolve it, the issue cannot be reproduced by the Airflow team based on the given information. In such cases, the issue is marked as ``Can’t Reproduce``.
+* Despite attempts to reproduce the issue to resolve it, the issue cannot be reproduced by the Airflow team based on the given information. In such cases, the issue is marked as ``Can't Reproduce``.
* In some cases, the original creator realizes that the issue was incorrectly reported and then marks it as ``invalid``. Also, a committer could mark it as ``invalid`` if the issue being reported is for an unsupported operation or environment.
* In some cases, the issue may be legitimate, but may not be addressed in the short to medium term based on current project priorities or because this will be irrelevant because of an upcoming change. The committer could mark this as ``wontfix`` to set expectations that it won't be directly addressed in the near term.
+
+**GitHub Discussions**
+
+Issues should represent clear feature requests which can/should be implemented. If the idea is vague or can be solved with easier steps
+we normally convert such issues to discussions in the Ideas category.
+Issues that seem more like support requests are also converted to discussions in the Q&A category.
+We use judgment about which issues to convert to discussions; it's best to always clarify with a comment why the issue is being converted.
+Note that we can always convert discussions back to issues.
diff --git a/LOCAL_VIRTUALENV.rst b/LOCAL_VIRTUALENV.rst
index 1cbf5b6b12e81..3208634a4ee62 100644
--- a/LOCAL_VIRTUALENV.rst
+++ b/LOCAL_VIRTUALENV.rst
@@ -205,21 +205,15 @@ Activate your virtualenv, e.g. by using ``workon``, and once you are in it, run:
.. code-block:: bash
- ./breeze-legacy initialize-local-virtualenv
+ ./scripts/tools/initialize_virtualenv.py
-By default Breeze installs the ``devel`` extra only. You can optionally control which extras are installed by exporting ``VIRTUALENV_EXTRAS`` before calling Breeze:
+By default Breeze installs the ``devel`` extra only. You can optionally control which extras are
+installed by adding extra dependencies as a parameter.
.. code-block:: bash
- export VIRTUALENV_EXTRAS="devel,google,postgres"
- ./breeze-legacy initialize-local-virtualenv
+ ./scripts/tools/initialize_virtualenv.py devel,google,postgres
-5. (optionally) run yarn build if you plan to run the webserver
-
-.. code-block:: bash
-
- cd airflow/www
- yarn build
Developing Providers
--------------------
diff --git a/MANIFEST.in b/MANIFEST.in
index 8f4b22b7ca01c..bfcaf6057c8a1 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -38,3 +38,4 @@ include airflow/customized_form_field_behaviours.schema.json
include airflow/serialization/schema.json
include airflow/utils/python_virtualenv_script.jinja2
include airflow/utils/context.pyi
+include generated
diff --git a/NOTICE b/NOTICE
index 4c7b795d88ce6..84c77cd4fc12c 100644
--- a/NOTICE
+++ b/NOTICE
@@ -20,3 +20,10 @@ This product contains a modified portion of 'Flask App Builder' developed by Dan
(https://github.com/dpgaspar/Flask-AppBuilder).
* Copyright 2013, Daniel Vaz Gaspar
+
+Chakra UI:
+-----
+This product contains a modified portion of 'Chakra UI' developed by Segun Adebayo.
+(https://github.com/chakra-ui/chakra-ui).
+
+* Copyright 2019, Segun Adebayo
diff --git a/PULL_REQUEST_WORKFLOW.rst b/PULL_REQUEST_WORKFLOW.rst
deleted file mode 100644
index d7ca2f9b93eaa..0000000000000
--- a/PULL_REQUEST_WORKFLOW.rst
+++ /dev/null
@@ -1,158 +0,0 @@
- .. Licensed to the Apache Software Foundation (ASF) under one
- or more contributor license agreements. See the NOTICE file
- distributed with this work for additional information
- regarding copyright ownership. The ASF licenses this file
- to you under the Apache License, Version 2.0 (the
- "License"); you may not use this file except in compliance
- with the License. You may obtain a copy of the License at
-
- .. http://www.apache.org/licenses/LICENSE-2.0
-
- .. Unless required by applicable law or agreed to in writing,
- software distributed under the License is distributed on an
- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- KIND, either express or implied. See the License for the
- specific language governing permissions and limitations
- under the License.
-
-.. contents:: :local:
-
-Why non-standard pull request workflow?
----------------------------------------
-
-This document describes the Pull Request Workflow we've implemented in Airflow. The workflow is slightly
-more complex than regular workflow you might encounter in most of the projects because after experiencing
-some huge delays in processing queues in October 2020 with GitHub Actions, we've decided to optimize the
-workflow to minimize the use of GitHub Actions build time by utilising selective approach on which tests
-and checks in the CI system are run depending on analysis of which files changed in the incoming PR and
-allowing the Committers to control the scope of the tests during the approval/review process.
-
-Just to give a bit of context, we started off with the approach that we always run all tests for all the
-incoming PRs, however due to our matrix of tests growing, this approach did not scale with the increasing
-number of PRs and when we had to compete with other Apache Software Foundation projects for the 180
-slots that are available for the whole organization. More Apache Software Foundation projects started
-to use GitHub Actions and we've started to experience long queues when our jobs waited for free slots.
-
-We approached the problem by:
-
-1) Improving mechanism of cancelling duplicate workflow runs more efficiently in case of queue conditions
- (duplicate workflow runs are generated when someone pushes a fixup quickly - leading to running both
- out-dated and current run to completion, taking precious slots. This has been implemented by improving
- `cancel-workflow-run `_ action we are using. In version
- 4.1 it got a new feature of cancelling all duplicates even if there is a long queue of builds.
-
-2) Heavily decreasing strain on the GitHub Actions jobs by introducing selective checks - mechanism
- to control which parts of the tests are run during the tests. This is implemented by the
- ``scripts/ci/selective_ci_checks.sh`` script in our repository. This script analyses which part of the
- code has changed and based on that it sets the right outputs that control which tests are executed in
- the ``Tests`` workflow, and whether we need to build CI images necessary to run those steps. This allowed to
- heavily decrease the strain especially for the Pull Requests that were not touching code (in which case
- the builds can complete in < 2 minutes) but also by limiting the number of tests executed in PRs that do
- not touch the "core" of Airflow, or only touching some - standalone - parts of Airflow such as
- "Providers", "WWW" or "CLI". This solution is not yet perfect as there are likely some edge cases but
- it is easy to maintain and we have an escape-hatch - all the tests are always executed in main pushes,
- so contributors can easily spot if there is a "missed" case and fix it - both by fixing the problem and
- adding those exceptions to the code. More about it can be found in `Selective checks `_
-
-3) Even more optimisation came from limiting the scope of tests to only "default" matrix parameters. So far
- in Airflow we always run all tests for all matrix combinations. The primary matrix components are:
-
- * Python versions (currently 3.7, 3.8, 3.9, 3.10)
- * Backend types (currently MySQL/Postgres)
- * Backed version (currently MySQL 5.7, MySQL 8, Postgres 13
-
- We've decided that instead of running all the combinations of parameters for all matrix component we will
- only run default values (Python 3.7, Mysql 5.7, Postgres 13) for all PRs which are not approved yet by
- the committers. This has a nice effect, that full set of tests (though with limited combinations of
- the matrix) are still run in the CI for every Pull Request that needs tests at all - allowing the
- contributors to make sure that their PR is "good enough" to be reviewed.
-
- Even after approval, the automated workflows we've implemented, check if the PR seems to need
- "full test matrix" and provide helpful information to both contributors and committers in the form of
- explanatory comments and labels set automatically showing the status of the PR. Committers have still
- control whether they want to merge such requests automatically or ask for rebase or re-run the tests
- and run "full tests" by applying the "full tests needed" label and re-running such request.
- The "full tests needed" label is also applied automatically after approval when the change touches
- the "core" of Airflow - also a separate check is added to the PR so that the "merge" button status
- will indicate to the committer that full tests are still needed. The committer might still decide,
- whether to merge such PR without the "full matrix". The "escape hatch" we have - i.e. running the full
- matrix of tests in the "merge push" will enable committers to catch and fix such problems quickly.
- More about it can be found in `Approval workflow and Matrix tests <#approval-workflow-and-matrix-tests>`_
- chapter.
-
-4) We've also applied (and received) funds to run self-hosted runners. They are used for ``main`` runs
- and whenever the PRs are done by one of the maintainers. Maintainers can force using Public GitHub runners
- by applying "use public runners" label to the PR before submitting it.
-
-
-Approval Workflow and Matrix tests
-----------------------------------
-
-As explained above the approval and matrix tests workflow works according to the algorithm below:
-
-1) In case of "no-code" changes - so changes that do not change any of the code or environment of
- the application, no test are run (this is done via selective checks). Also no CI/PROD images are
- build saving extra minutes. Such build takes less than 2 minutes currently and only few jobs are run
- which is a very small fraction of the "full build" time.
-
-2) When new PR is created, only a "default set" of matrix test are running. Only default
- values for each of the parameters are used effectively limiting it to running matrix builds for only
- one python version and one version of each of the backends. In this case only one CI and one PROD
- image is built, saving precious job slots. This build takes around 50% less time than the "full matrix"
- build.
-
-3) When such PR gets approved, the system further analyses the files changed in this PR and further
- decision is made that should be communicated to both Committer and Reviewer.
-
-3a) In case of "no-code" builds, a message is communicated that the PR is ready to be merged and
- no tests are needed.
-
-.. image:: images/pr/pr-no-tests-needed-comment.png
- :align: center
- :alt: No tests needed for "no-code" builds
-
-3b) In case of "non-core" builds a message is communicated that such PR is likely OK to be merged as is with
- limited set of tests, but that the committer might decide to re-run the PR after applying
- "full tests needed" label, which will trigger full matrix build for tests for this PR. The committer
- might make further decision on what to do with this PR.
-
-.. image:: images/pr/pr-likely-ok-to-merge.png
- :align: center
- :alt: Likely ok to merge the PR with only small set of tests
-
-3c) In case of "core" builds (i. e. when the PR touches some "core" part of Airflow) a message is
- communicated that this PR needs "full test matrix", the "full tests needed" label is applied
- automatically and either the contributor might rebase the request to trigger full test build or the
- committer might re-run the build manually to trigger such full test rebuild. Also a check "in-progress"
- is added, so that the committer realises that the PR is not yet "green to merge". Pull requests with
- "full tests needed" label always trigger the full matrix build when rebased or re-run so if the
- PR gets rebased, it will continue triggering full matrix build.
-
-.. image:: images/pr/pr-full-tests-needed.png
- :align: center
- :alt: Full tests are needed for the PR
-
-4) If this or another committer "request changes" in a previously approved PR with "full tests needed"
- label, the bot automatically removes the label, moving it back to "run only default set of parameters"
- mode. For PRs touching core of airflow once the PR gets approved back, the label will be restored.
- If it was manually set by the committer, it has to be restored manually.
-
-.. note:: Note that setting the labels and adding comments might be delayed, due to limitation of GitHub Actions,
- in case of queues, processing of Pull Request reviews might take some time, so it is advised not to merge
- PR immediately after approval. Luckily, the comments describing the status of the PR trigger notifications
- for the PRs and they provide good "notification" for the committer to act on a PR that was recently
- approved.
-
-The PR approval workflow is possible thanks to two custom GitHub Actions we've developed:
-
-* `Get workflow origin `_
-* `Label when approved `_
-
-
-Next steps
-----------
-
-We are planning to also propose the approach to other projects from Apache Software Foundation to
-make it a common approach, so that our effort is not limited only to one project.
-
-Discussion about it in `this discussion `_
diff --git a/README.md b/README.md
index 3acb883dc4ba8..3b5e704e5f086 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Twitter Follow](https://img.shields.io/twitter/follow/ApacheAirflow.svg?style=social&label=Follow)](https://twitter.com/ApacheAirflow)
[![Slack Status](https://img.shields.io/badge/slack-join_chat-white.svg?logo=slack&style=social)](https://s.apache.org/airflow-slack)
+[![Contributors](https://img.shields.io/github/contributors/apache/airflow)](https://github.com/apache/airflow/graphs/contributors)
[Apache Airflow](https://airflow.apache.org/docs/apache-airflow/stable/) (or simply Airflow) is a platform to programmatically author, schedule, and monitor workflows.
@@ -55,7 +56,7 @@ Use Airflow to author workflows as directed acyclic graphs (DAGs) of tasks. The
- [Support for Python and Kubernetes versions](#support-for-python-and-kubernetes-versions)
- [Base OS support for reference Airflow images](#base-os-support-for-reference-airflow-images)
- [Approach to dependencies of Airflow](#approach-to-dependencies-of-airflow)
-- [Support for providers](#support-for-providers)
+- [Release process for Providers](#release-process-for-providers)
- [Contributing](#contributing)
- [Who uses Apache Airflow?](#who-uses-apache-airflow)
- [Who Maintains Apache Airflow?](#who-maintains-apache-airflow)
@@ -70,7 +71,7 @@ Use Airflow to author workflows as directed acyclic graphs (DAGs) of tasks. The
Airflow works best with workflows that are mostly static and slowly changing. When the DAG structure is similar from one run to the next, it clarifies the unit of work and continuity. Other similar projects include [Luigi](https://github.com/spotify/luigi), [Oozie](https://oozie.apache.org/) and [Azkaban](https://azkaban.github.io/).
-Airflow is commonly used to process data, but has the opinion that tasks should ideally be idempotent (i.e., results of the task will be the same, and will not create duplicated data in a destination system), and should not pass large quantities of data from one task to the next (though tasks can pass metadata using Airflow's [Xcom feature](https://airflow.apache.org/docs/apache-airflow/stable/concepts.html#xcoms)). For high-volume, data-intensive tasks, a best practice is to delegate to external services specializing in that type of work.
+Airflow is commonly used to process data, but has the opinion that tasks should ideally be idempotent (i.e., results of the task will be the same, and will not create duplicated data in a destination system), and should not pass large quantities of data from one task to the next (though tasks can pass metadata using Airflow's [XCom feature](https://airflow.apache.org/docs/apache-airflow/stable/concepts/xcoms.html)). For high-volume, data-intensive tasks, a best practice is to delegate to external services specializing in that type of work.
Airflow is not a streaming solution, but it is often used to process real-time data, pulling data off streams in batches.
@@ -85,12 +86,12 @@ Airflow is not a streaming solution, but it is often used to process real-time d
Apache Airflow is tested with:
-| | Main version (dev) | Stable version (2.3.1) |
+| | Main version (dev) | Stable version (2.5.1) |
|---------------------|------------------------------|------------------------------|
| Python | 3.7, 3.8, 3.9, 3.10 | 3.7, 3.8, 3.9, 3.10 |
| Platform | AMD64/ARM64(\*) | AMD64/ARM64(\*) |
-| Kubernetes | 1.20, 1.21, 1.22, 1.23, 1.24 | 1.20, 1.21, 1.22, 1.23, 1.24 |
-| PostgreSQL | 10, 11, 12, 13, 14 | 10, 11, 12, 13, 14 |
+| Kubernetes | 1.21, 1.22, 1.23, 1.24, 1.25 | 1.21, 1.22, 1.23, 1.24, 1.25 |
+| PostgreSQL | 11, 12, 13, 14, 15 | 11, 12, 13, 14, 15 |
| MySQL | 5.7, 8 | 5.7, 8 |
| SQLite | 3.15.0+ | 3.15.0+ |
| MSSQL | 2017(\*), 2019 (\*) | 2017(\*), 2019 (\*) |
@@ -104,9 +105,6 @@ MariaDB is not tested/recommended.
**Note**: SQLite is used in Airflow tests. Do not use it in production. We recommend
using the latest stable version of SQLite for local development.
-**Note**: Support for Python v3.10 will be available from Airflow 2.3.0. The `main` (development) branch
-already supports Python 3.10.
-
**Note**: Airflow currently can be run on POSIX-compliant Operating Systems. For development it is regularly
tested on fairly modern Linux Distros and recent versions of MacOS.
On Windows you can run it via WSL2 (Windows Subsystem for Linux 2) or via Linux Containers.
@@ -120,13 +118,13 @@ is used in the [Community managed DockerHub image](https://hub.docker.com/p/apac
Visit the official Airflow website documentation (latest **stable** release) for help with
[installing Airflow](https://airflow.apache.org/docs/apache-airflow/stable/installation.html),
-[getting started](https://airflow.apache.org/docs/apache-airflow/stable/start/index.html), or walking
+[getting started](https://airflow.apache.org/docs/apache-airflow/stable/start.html), or walking
through a more complete [tutorial](https://airflow.apache.org/docs/apache-airflow/stable/tutorial.html).
> Note: If you're looking for documentation for the main branch (latest development branch): you can find it on [s.apache.org/airflow-docs](https://s.apache.org/airflow-docs/).
For more information on Airflow Improvement Proposals (AIPs), visit
-the [Airflow Wiki](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvements+Proposals).
+the [Airflow Wiki](https://cwiki.apache.org/confluence/display/AIRFLOW/Airflow+Improvement+Proposals).
Documentation for dependent projects like provider packages, Docker image, Helm Chart, you'll find it in [the documentation index](https://airflow.apache.org/docs/).
@@ -160,15 +158,15 @@ them to the appropriate format and workflow that your tool requires.
```bash
-pip install 'apache-airflow==2.3.1' \
- --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.1/constraints-3.7.txt"
+pip install 'apache-airflow==2.5.1' \
+ --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.5.1/constraints-3.7.txt"
```
2. Installing with extras (i.e., postgres, google)
```bash
-pip install 'apache-airflow[postgres,google]==2.3.1' \
- --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.3.1/constraints-3.7.txt"
+pip install 'apache-airflow[postgres,google]==2.5.1' \
+ --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.5.1/constraints-3.7.txt"
```
For information on installing provider packages, check
@@ -273,7 +271,7 @@ Apache Airflow version life cycle:
| Version | Current Patch/Minor | State | First Release | Limited Support | EOL/Terminated |
|-----------|-----------------------|-----------|-----------------|-------------------|------------------|
-| 2 | 2.3.1 | Supported | Dec 17, 2020 | TBD | TBD |
+| 2 | 2.5.1 | Supported | Dec 17, 2020 | TBD | TBD |
| 1.10 | 1.10.15 | EOL | Aug 27, 2018 | Dec 17, 2020 | June 17, 2021 |
| 1.9 | 1.9.0 | EOL | Jan 03, 2018 | Aug 27, 2018 | Aug 27, 2018 |
| 1.8 | 1.8.2 | EOL | Mar 19, 2017 | Jan 03, 2018 | Jan 03, 2018 |
@@ -293,8 +291,8 @@ They are based on the official release schedule of Python and Kubernetes, nicely
[Python Developer's Guide](https://devguide.python.org/#status-of-python-branches) and
[Kubernetes version skew policy](https://kubernetes.io/docs/setup/release/version-skew-policy/).
-1. We drop support for Python and Kubernetes versions when they reach EOL. Except for kubernetes, a
- version stay supported by Airflow if two major cloud provider still provide support for it. We drop
+1. We drop support for Python and Kubernetes versions when they reach EOL. Except for Kubernetes, a
+ version stays supported by Airflow if two major cloud providers still provide support for it. We drop
support for those EOL versions in main right after EOL date, and it is effectively removed when we release
the first new MINOR (Or MAJOR if there is no new MINOR version) of Airflow. For example, for Python 3.7 it
means that we will drop support in main right after 27.06.2023, and the first MAJOR or MINOR version of
@@ -303,7 +301,7 @@ They are based on the official release schedule of Python and Kubernetes, nicely
2. The "oldest" supported version of Python/Kubernetes is the default one until we decide to switch to
later version. "Default" is only meaningful in terms of "smoke tests" in CI PRs, which are run using this
default version and the default reference image available. Currently `apache/airflow:latest`
- and `apache/airflow:2.3.1` images are Python 3.7 images. This means that default reference image will
+ and `apache/airflow:2.5.1` images are Python 3.7 images. This means that default reference image will
become the default at the time when we start preparing for dropping 3.7 support which is few months
before the end of life for Python 3.7.
@@ -321,7 +319,7 @@ we publish an Apache Airflow release. Those images contain:
Airflow released (so there could be different versions for 2.3 and 2.2 line for example)
* Libraries required to connect to supported Databases (again the set of databases supported depends
on the MINOR version of Airflow.
-* Predefined set of popular providers (for details see the [Dockerfile](Dockerfile)).
+* Predefined set of popular providers (for details see the [Dockerfile](https://raw.githubusercontent.com/apache/airflow/main/Dockerfile)).
* Possibility of building your own, custom image where the user can choose their own set of providers
and libraries (see [Building the image](https://airflow.apache.org/docs/docker-stack/build.html))
* In the future Airflow might also support a "slim" version without providers nor database clients installed
@@ -330,14 +328,17 @@ The version of the base OS image is the stable version of Debian. Airflow suppor
stable versions - as soon as all Airflow dependencies support building, and we set up the CI pipeline for
building and testing the OS version. Approximately 6 months before the end-of-life of a previous stable
version of the OS, Airflow switches the images released to use the latest supported version of the OS.
-For example since Debian Buster end-of-life is August 2022, Airflow switches the images in `main` branch
-to use Debian Bullseye in February/March 2022. The version will be used in the next MINOR release after
-the switch happens. In case of the Bullseye switch - 2.3.0 version will use Bullseye. The images released
-in the previous MINOR version continue to use the version that all other releases for the MINOR version
-used.
+For example since ``Debian Buster`` end-of-life was August 2022, Airflow switched the images in `main` branch
+to use ``Debian Bullseye`` in February/March 2022. The version was used in the next MINOR release after
+the switch happened. In case of the Bullseye switch - 2.3.0 version used ``Debian Bullseye``.
+The images released in the previous MINOR version continue to use the version that all other releases
+for the MINOR version used.
+
+Support for ``Debian Buster`` image was dropped in August 2022 completely and everyone is expected to
+stop building their images using ``Debian Buster``.
Users will continue to be able to build their images using stable Debian releases until the end of life and
-building and verifying of the images happens in our CI but no unit tests are executed using this image in
+building and verifying of the images happens in our CI but no unit tests were executed using this image in
the `main` branch.
## Approach to dependencies of Airflow
@@ -391,22 +392,81 @@ The important dependencies are:
### Approach for dependencies in Airflow Providers and extras
-Those `extras` and `providers` dependencies are maintained in `setup.py`.
+Those `extras` and `providers` dependencies are maintained in `provider.yaml` of each provider.
By default, we should not upper-bound dependencies for providers, however each provider's maintainer
might decide to add additional limits (and justify them with comment)
-## Support for providers
+## Release process for Providers
-Providers released by the community have limitation of a minimum supported version of Airflow. The minimum
-version of Airflow is the `MINOR` version (2.1, 2.2 etc.) indicating that the providers might use features
-that appeared in this release. The default support timespan for the minimum version of Airflow
-(there could be justified exceptions) is that we increase the minimum Airflow version, when 12 months passed
-since the first release for the MINOR version of Airflow.
+Providers released by the community (with roughly monthly cadence) have a
+limitation of a minimum supported version of Airflow. The minimum version of
+Airflow is the `MINOR` version (2.2, 2.3 etc.) indicating that the providers
+might use features that appeared in this release. The default support timespan
+for the minimum version of Airflow (there could be justified exceptions) is
+that we increase the minimum Airflow version when 12 months have passed since the
+first release for the MINOR version of Airflow.
For example this means that by default we upgrade the minimum version of Airflow supported by providers
-to 2.2.0 in the first Provider's release after 21st of May 2022 (21st of May 2021 is the date when the
-first `PATCHLEVEL` of 2.1 (2.1.0) has been released.
+to 2.4.0 in the first Provider's release after 30th of April 2023. The 30th of April 2022 is the date when the
+first `PATCHLEVEL` of 2.3 (2.3.0) was released.
+
+When we increase the minimum Airflow version, this is not a reason to bump `MAJOR` version of the providers
+(unless there are other breaking changes in the provider). The reason for that is that people who use
+older version of Airflow will not be able to use that provider (so it is not a breaking change for them)
+and for people who are using supported version of Airflow this is not a breaking change on its own - they
+will be able to use the new version without breaking their workflows. When we upgraded min-version to
+2.2+, our approach was different but as of 2.3+ upgrade (November 2022) we only bump `MINOR` version of the
+provider when we increase minimum Airflow version.
+
+Providers are often connected with some stakeholders that are vitally interested in maintaining backwards
+compatibilities in their integrations (for example cloud providers, or specific service providers). But,
+we are also bound with the [Apache Software Foundation release policy](https://www.apache.org/legal/release-policy.html)
+which describes who releases, and how to release the ASF software. The provider's governance model is something we name
+"mixed governance" - where we follow the release policies, while the burden of maintaining and testing
+the cherry-picked versions is on those who commit to perform the cherry-picks and make PRs to older
+branches.
+
+The "mixed governance" (optional, per-provider) means that:
+
+* The Airflow Community and release manager decide when to release those providers.
+ This is fully managed by the community and the usual release-management process following the
+ [Apache Software Foundation release policy](https://www.apache.org/legal/release-policy.html)
+* The contributors (who might or might not be direct stakeholders in the provider) will carry the burden
+ of cherry-picking and testing the older versions of providers.
+* There is no "selection" and acceptance process to determine which version of the provider is released.
+ It is determined by the actions of contributors raising the PR with cherry-picked changes and it follows
+ the usual PR review process where maintainer approves (or not) and merges (or not) such PR. Simply
+ speaking - the completed action of cherry-picking and testing the older version of the provider makes
+ it eligible to be released. Unless there is someone who volunteers and performs the cherry-picking and
+ testing, the provider is not released.
+* Branches to raise PR against are created when a contributor commits to perform the cherry-picking
+ (as a comment in PR to cherry-pick for example)
+
+Usually, community effort is focused on the most recent version of each provider. The community approach is
+that we should rather aggressively remove deprecations in "major" versions of the providers - whenever
+there is an opportunity to increase major version of a provider, we attempt to remove all deprecations.
+However, sometimes there is a contributor (who might or might not represent stakeholder),
+willing to make their effort on cherry-picking and testing the non-breaking changes to a selected,
+previous major branch of the provider. This results in releasing at most two versions of a
+provider at a time:
+
+* potentially breaking "latest" major version
+* selected past major version with non-breaking changes applied by the contributor
+
+Cherry-picking such changes follows the same process as releasing Airflow
+patch-level releases for a previous minor Airflow version. Usually such cherry-picking is done when
+there is an important bugfix and the latest version contains breaking changes that are not
+coupled with the bugfix. Releasing them together in the latest version of the provider effectively couples
+them, and therefore they're released separately. The cherry-picked changes have to be merged by the committer following the usual rules of the
+community.
+
+There is no obligation to cherry-pick and release older versions of the providers.
+The community continues to release such older versions of the providers for as long as there is an effort
+of the contributors to perform the cherry-picks and carry on testing of the older provider version.
+
+The availability of a stakeholder who can manage "service-oriented" maintenance and agrees to such a
+responsibility will also drive our willingness to accept future, new providers to become community managed.
## Contributing
diff --git a/RELEASE_NOTES.rst b/RELEASE_NOTES.rst
index 87319b55aaebb..e879df02c328a 100644
--- a/RELEASE_NOTES.rst
+++ b/RELEASE_NOTES.rst
@@ -21,6 +21,1146 @@
.. towncrier release notes start
+Airflow 2.5.1 (2023-01-16)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+Trigger gevent ``monkeypatching`` via environment variable (#28283)
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+ If you are using gevent for your webserver deployment and used local settings to ``monkeypatch`` gevent,
+ you might want to replace local settings patching with an ``_AIRFLOW_PATCH_GEVENT`` environment variable
+ set to 1 in your webserver. This ensures gevent patching is done as early as possible. (#8212)
+
+Bug Fixes
+^^^^^^^^^
+- Fix masking of non-sensitive environment variables (#28802)
+- Remove swagger-ui extra from connexion and install ``swagger-ui-dist`` via npm package (#28788)
+- Fix ``UIAlert`` should_show when ``AUTH_ROLE_PUBLIC`` set (#28781)
+- Only patch single label when adopting pod (#28776)
+- Update CSRF token to expire with session (#28730)
+- Fix "airflow tasks render" cli command for mapped task instances (#28698)
+- Allow XComArgs for ``external_task_ids`` of ExternalTaskSensor (#28692)
+- Row-lock TIs to be removed during mapped task expansion (#28689)
+- Handle ConnectionReset exception in Executor cleanup (#28685)
+- Fix description of output redirection for access_log for gunicorn (#28672)
+- Add back join to zombie query that was dropped in #28198 (#28544)
+- Fix calendar view for CronTriggerTimeTable dags (#28411)
+- After running the DAG the employees table is empty. (#28353)
+- Fix ``DetachedInstanceError`` when finding zombies in Dag Parsing process (#28198)
+- Nest header blocks in ``divs`` to fix ``dagid`` copy nit on dag.html (#28643)
+- Fix UI caret direction (#28624)
+- Guard not-yet-expanded ti in trigger rule dep (#28592)
+- Move TI ``setNote`` endpoints under TaskInstance in OpenAPI (#28566)
+- Consider previous run in ``CronTriggerTimetable`` (#28532)
+- Ensure correct log dir in file task handler (#28477)
+- Fix bad pods pickled in executor_config (#28454)
+- Add ``ensure_ascii=False`` in trigger dag run API (#28451)
+- Add setters to MappedOperator on_*_callbacks (#28313)
+- Fix ``ti._try_number`` for deferred and up_for_reschedule tasks (#26993)
+- separate ``callModal`` from dag.js (#28410)
+- A manual run can't look like a scheduled one (#28397)
+- Don't show task/run durations when there is no start_date (#28395)
+- Maintain manual scroll position in task logs (#28386)
+- Correctly select a mapped task's "previous" task (#28379)
+- Trigger gevent ``monkeypatching`` via environment variable (#28283)
+- Fix db clean warnings (#28243)
+- Make arguments 'offset' and 'length' not required (#28234)
+- Make live logs reading work for "other" k8s executors (#28213)
+- Add custom pickling hooks to ``LazyXComAccess`` (#28191)
+- fix next run datasets error (#28165)
+- Ensure that warnings from ``@dag`` decorator are reported in dag file (#28153)
+- Do not warn when the ``airflow dags test`` command is used (#28138)
+- Ensure the ``dagbag_size`` metric decreases when files are deleted (#28135)
+- Improve run/task grid view actions (#28130)
+- Make BaseJob.most_recent_job favor "running" jobs (#28119)
+- Don't emit FutureWarning when code not calling old key (#28109)
+- Add ``airflow.api.auth.backend.session`` to backend sessions in compose (#28094)
+- Resolve false warning about calling conf.get on moved item (#28075)
+- Return list of tasks that will be changed (#28066)
+- Handle bad zip files nicely when parsing DAGs. (#28011)
+- Prevent double loading of providers from local paths (#27988)
+- Fix deadlock when chaining multiple empty mapped tasks (#27964)
+- fix: current_state method on TaskInstance doesn't filter by map_index (#27898)
+- Don't log CLI actions if db not initialized (#27851)
+- Make sure we can get out of a faulty scheduler state (#27834)
+- dagrun, ``next_dagruns_to_examine``, add MySQL index hint (#27821)
+- Handle DAG disappearing mid-flight when dag verification happens (#27720)
+- fix: continue checking sla (#26968)
+- Allow generation of connection URI to work when no conn type (#26765)
+
+Misc/Internal
+^^^^^^^^^^^^^
+- Add automated version replacement in example dag indexes (#28090)
+- Cleanup and do housekeeping with plugin examples (#28537)
+- Limit ``SQLAlchemy`` to below ``2.0`` (#28725)
+- Bump ``json5`` from ``1.0.1`` to ``1.0.2`` in ``/airflow/www`` (#28715)
+- Fix some docs on using sensors with taskflow (#28708)
+- Change Architecture and OperatingSystem classes into ``Enums`` (#28627)
+- Add doc-strings and small improvement to email util (#28634)
+- Fix ``Connection.get_extra`` type (#28594)
+- navbar, cap dropdown size, and add scroll bar (#28561)
+- Emit warnings for ``conf.get*`` from the right source location (#28543)
+- Move MyPY plugins of ours to dev folder (#28498)
+- Add retry to ``purge_inactive_dag_warnings`` (#28481)
+- Re-enable Plyvel on ARM as it now builds cleanly (#28443)
+- Add SIGUSR2 handler for LocalTaskJob and workers to aid debugging (#28309)
+- Convert ``test_task_command`` to Pytest and ``unquarantine`` tests in it (#28247)
+- Make invalid characters exception more readable (#28181)
+- Bump decode-uri-component from ``0.2.0`` to ``0.2.2`` in ``/airflow/www`` (#28080)
+- Use asserts instead of exceptions for executor not started (#28019)
+- Simplify dataset ``subgraph`` logic (#27987)
+- Order TIs by ``map_index`` (#27904)
+- Additional info about Segmentation Fault in ``LocalTaskJob`` (#27381)
+
+Doc Only Changes
+^^^^^^^^^^^^^^^^
+- Mention mapped operator in cluster policy doc (#28885)
+- Slightly improve description of Dynamic DAG generation preamble (#28650)
+- Restructure Docs (#27235)
+- Update scheduler docs about low priority tasks (#28831)
+- Clarify that versioned constraints are fixed at release time (#28762)
+- Clarify about docker compose (#28729)
+- Adding an example dag for dynamic task mapping (#28325)
+- Use docker compose v2 command (#28605)
+- Add AIRFLOW_PROJ_DIR to docker-compose example (#28517)
+- Remove outdated Optional Provider Feature outdated documentation (#28506)
+- Add documentation for [core] mp_start_method config (#27993)
+- Documentation for the LocalTaskJob return code counter (#27972)
+- Note which versions of Python are supported (#27798)
+
+
+Airflow 2.5.0 (2022-12-02)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+``airflow dags test`` no longer performs a backfill job (#26400)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+ In order to make ``airflow dags test`` more useful as a testing and debugging tool, we no
+ longer run a backfill job and instead run a "local task runner". Users can still backfill
+ their DAGs using the ``airflow dags backfill`` command.
+
+Airflow config section ``kubernetes`` renamed to ``kubernetes_executor`` (#26873)
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+ KubernetesPodOperator no longer considers any core kubernetes config params, so this section now only applies to kubernetes executor. Renaming it reduces potential for confusion.
+
+``AirflowException`` is now thrown as soon as any dependent task of ExternalTaskSensor fails (#27190)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+ ``ExternalTaskSensor`` no longer hangs indefinitely when ``failed_states`` is set, an ``execution_date_fn`` is used, and some but not all of the dependent tasks fail.
+ Instead, an ``AirflowException`` is thrown as soon as any of the dependent tasks fail.
+ Any code handling this failure in addition to timeouts should move to catching the ``AirflowException`` base class and not only the ``AirflowSensorTimeout`` subclass.
+
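A minimal sketch of why the ``except`` clause needs widening. The two classes are local stand-ins that reproduce only the subclass relationship described above, so the snippet runs without Airflow installed. ``wait_for_upstream`` and ``handle`` are hypothetical helpers, not Airflow APIs.

```python
class AirflowException(Exception):
    """Stand-in for airflow.exceptions.AirflowException."""


class AirflowSensorTimeout(AirflowException):
    """Stand-in for airflow.exceptions.AirflowSensorTimeout."""


def wait_for_upstream(outcome):
    """Hypothetical helper that mimics the sensor's two failure modes."""
    if outcome == "timeout":
        raise AirflowSensorTimeout("sensor timed out")
    if outcome == "upstream_failed":
        raise AirflowException("an external task failed")
    return "ok"


def handle(outcome):
    """Handle the timeout first, then any other AirflowException."""
    try:
        return wait_for_upstream(outcome)
    except AirflowSensorTimeout:
        return "handled timeout"
    except AirflowException:
        return "handled upstream failure"
```

A handler that lists only ``AirflowSensorTimeout`` would let the upstream failure propagate. Listing the subclass first keeps timeout-specific handling intact.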
+The Airflow config option ``scheduler.deactivate_stale_dags_interval`` has been renamed to ``scheduler.parsing_cleanup_interval`` (#27828).
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+ The old option will continue to work but will issue deprecation warnings, and will be removed entirely in Airflow 3.
+
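The fallback behaviour can be pictured with the standard library alone. This is a sketch of the general "new key first, old key with a warning" pattern, not Airflow's configuration code. ``get_cleanup_interval`` and its ``default=60`` are hypothetical placeholders.

```python
import configparser
import warnings

OLD_KEY = "deactivate_stale_dags_interval"
NEW_KEY = "parsing_cleanup_interval"


def get_cleanup_interval(config, default=60):
    """Prefer the new key; fall back to the old one with a deprecation warning."""
    if config.has_option("scheduler", NEW_KEY):
        return config.getint("scheduler", NEW_KEY)
    if config.has_option("scheduler", OLD_KEY):
        warnings.warn(
            f"[scheduler] {OLD_KEY} is deprecated; use {NEW_KEY} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return config.getint("scheduler", OLD_KEY)
    return default


# A configuration file that still uses the old option name.
legacy_config = configparser.ConfigParser()
legacy_config.read_string("[scheduler]\ndeactivate_stale_dags_interval = 120\n")
```

Renaming the option in ``airflow.cfg`` silences the warning without changing the value.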
+New Features
+^^^^^^^^^^^^
+- ``TaskRunner``: notify of component start and finish (#27855)
+- Add DagRun state change to the Listener plugin system (#27113)
+- Metric for raw task return codes (#27155)
+- Add logic for XComArg to pull specific map indexes (#27771)
+- Clear TaskGroup (#26658, #28003)
+- Add critical section query duration metric (#27700)
+- Add: #23880 :: Audit log for ``AirflowModelViews(Variables/Connection)`` (#24079, #27994, #27923)
+- Add postgres 15 support (#27444)
+- Expand tasks in mapped group at run time (#27491)
+- reset commits, clean submodules (#27560)
+- scheduler_job, add metric for scheduler loop timer (#27605)
+- Allow datasets to be used in taskflow (#27540)
+- Add expanded_ti_count to ti context (#27680)
+- Add user comment to task instance and dag run (#26457, #27849, #27867)
+- Enable copying DagRun JSON to clipboard (#27639)
+- Implement extra controls for SLAs (#27557)
+- add dag parsed time in DAG view (#27573)
+- Add max_wait for exponential_backoff in BaseSensor (#27597)
+- Expand tasks in mapped group at parse time (#27158)
+- Add disable retry flag on backfill (#23829)
+- Adding sensor decorator (#22562)
+- Api endpoint update ti (#26165)
+- Filtering datasets by recent update events (#26942)
+- Support ``Is /not`` Null filter for value is None on ``webui`` (#26584)
+- Add search to datasets list (#26893)
+- Split out and handle 'params' in mapped operator (#26100)
+- Add authoring API for TaskGroup mapping (#26844)
+- Add ``one_done`` trigger rule (#26146)
+- Create a more efficient airflow dag test command that also has better local logging (#26400)
+- Support add/remove permissions to roles commands (#26338)
+- Auto tail file logs in Web UI (#26169)
+- Add triggerer info to task instance in API (#26249)
+- Flag to deserialize value on custom XCom backend (#26343)
+
+Improvements
+^^^^^^^^^^^^
+- Allow depth-first execution (#27827)
+- UI: Update offset height if data changes (#27865)
+- Improve TriggerRuleDep typing and readability (#27810)
+- Make views requiring session, keyword only args (#27790)
+- Optimize ``TI.xcom_pull()`` with explicit task_ids and map_indexes (#27699)
+- Allow hyphens in pod id used by k8s executor (#27737)
+- optimise task instances filtering (#27102)
+- Use context managers to simplify log serve management (#27756)
+- Fix formatting leftovers (#27750)
+- Improve task deadlock messaging (#27734)
+- Improve "sensor timeout" messaging (#27733)
+- Replace urlparse with ``urlsplit`` (#27389)
+- Align TaskGroup semantics to AbstractOperator (#27723)
+- Add new files to parsing queue on every loop of dag processing (#27060)
+- Make Kubernetes Executor & Scheduler resilient to error during PMH execution (#27611)
+- Separate dataset deps into individual graphs (#27356)
+- Use log.exception where more economical than log.error (#27517)
+- Move validation ``branch_task_ids`` into ``SkipMixin`` (#27434)
+- Coerce LazyXComAccess to list when pushed to XCom (#27251)
+- Update cluster-policies.rst docs (#27362)
+- Add warning if connection type already registered within the provider (#27520)
+- Activate debug logging in commands with --verbose option (#27447)
+- Add classic examples for Python Operators (#27403)
+- change ``.first()`` to ``.scalar()`` (#27323)
+- Improve reset_dag_run description (#26755)
+- Add examples and ``howtos`` about sensors (#27333)
+- Make grid view widths adjustable (#27273)
+- Sorting plugins custom menu links by category before name (#27152)
+- Simplify DagRun.verify_integrity (#26894)
+- Add mapped task group info to serialization (#27027)
+- Correct the JSON style used for Run config in Grid View (#27119)
+- No ``extra__conn_type__`` prefix required for UI behaviors (#26995)
+- Improve dataset update blurb (#26878)
+- Rename kubernetes config section to kubernetes_executor (#26873)
+- decode params for dataset searches (#26941)
+- Get rid of the DAGRun details page & rely completely on Grid (#26837)
+- Fix scheduler ``crashloopbackoff`` when using ``hostname_callable`` (#24999)
+- Reduce log verbosity in KubernetesExecutor. (#26582)
+- Don't iterate tis list twice for no reason (#26740)
+- Clearer code for PodGenerator.deserialize_model_file (#26641)
+- Don't import kubernetes unless you have a V1Pod (#26496)
+- Add updated_at column to DagRun and Ti tables (#26252)
+- Move the deserialization of custom XCom Backend to 2.4.0 (#26392)
+- Avoid calculating all elements when one item is needed (#26377)
+- Add ``__future__``.annotations automatically by isort (#26383)
+- Handle list when serializing expand_kwargs (#26369)
+- Apply PEP-563 (Postponed Evaluation of Annotations) to core airflow (#26290)
+- Add more weekday operator and sensor examples #26071 (#26098)
+
+Bug Fixes
+^^^^^^^^^
+- Gracefully handle whole config sections being renamed (#28008)
+- Add allow list for imports during deserialization (#27887)
+- Soft delete datasets that are no longer referenced in DAG schedules or task outlets (#27828)
+- Redirect to home view when there are no valid tags in the URL (#25715)
+- Refresh next run datasets info in dags view (#27839)
+- Make MappedTaskGroup depend on its expand inputs (#27876)
+- Make DagRun state updates for paused DAGs faster (#27725)
+- Don't explicitly set include_examples to False on task run command (#27813)
+- Fix menu border color (#27789)
+- Fix backfill queued task getting reset to scheduled state. (#23720)
+- Fix clearing child dag mapped tasks from parent dag (#27501)
+- Handle json encoding of ``V1Pod`` in task callback (#27609)
+- Fix ExternalTaskSensor can't check zipped dag (#27056)
+- Avoid re-fetching DAG run in TriggerDagRunOperator (#27635)
+- Continue on exception when retrieving metadata (#27665)
+- External task sensor fail fix (#27190)
+- Add the default None when pop actions (#27537)
+- Display parameter values from serialized dag in trigger dag view. (#27482, #27944)
+- Move TriggerDagRun conf check to execute (#27035)
+- Resolve trigger assignment race condition (#27072)
+- Update google_analytics.html (#27226)
+- Fix some bugs in web UI dags list page (auto-refresh & jump search null state) (#27141)
+- Fixed broken URL for docker-compose.yaml (#26721)
+- Fix xcom arg.py .zip bug (#26636)
+- Fix 404 ``taskInstance`` errors and split into two tables (#26575)
+- Fix browser warning of improper thread usage (#26551)
+- template rendering issue fix (#26390)
+- Clear ``autoregistered`` DAGs if there are any import errors (#26398)
+- Fix ``from airflow import version`` lazy import (#26239)
+- allow scroll in triggered dag runs modal (#27965)
+
+Misc/Internal
+^^^^^^^^^^^^^
+- Remove ``is_mapped`` attribute (#27881)
+- Simplify FAB table resetting (#27869)
+- Fix old-style typing in Base Sensor (#27871)
+- Switch (back) to late imports (#27730)
+- Completed D400 for multiple folders (#27748)
+- simplify notes accordion test (#27757)
+- Completed D400 for ``airflow/callbacks/*`` and ``airflow/cli/*`` (#27721)
+- Completed D400 for ``airflow/api_connexion/*`` directory (#27718)
+- Completed D400 for ``airflow/listener/*`` directory (#27731)
+- Completed D400 for ``airflow/lineage/*`` directory (#27732)
+- Update API & Python Client versions (#27642)
+- Completed D400 & D401 for ``airflow/api/*`` directory (#27716)
+- Completed D400 for multiple folders (#27722)
+- Bump ``minimatch`` from ``3.0.4`` to ``3.0.8`` in ``/airflow/www`` (#27688)
+- Bump loader-utils from ``1.4.1`` to ``1.4.2`` in ``/airflow/www`` (#27697)
+- Disable nested task mapping for now (#27681)
+- bump alembic minimum version (#27629)
+- remove unused code.html (#27585)
+- Enable python string normalization everywhere (#27588)
+- Upgrade dependencies in order to avoid backtracking (#27531)
+- Strengthen a bit and clarify importance of triaging issues (#27262)
+- Deduplicate type hints (#27508)
+- Add stub 'yield' to ``BaseTrigger.run`` (#27416)
+- Remove upper-bound limit to dask (#27415)
+- Limit Dask to under ``2022.10.1`` (#27383)
+- Update old style typing (#26872)
+- Enable string normalization for docs (#27269)
+- Slightly faster up/downgrade tests (#26939)
+- Deprecate use of core get_kube_client in PodManager (#26848)
+- Add ``memray`` files to ``.gitignore`` / ``.dockerignore`` (#27001)
+- Bump sphinx and ``sphinx-autoapi`` (#26743)
+- Simplify ``RTIF.delete_old_records()`` (#26667)
+- migrate last react files to typescript (#26112)
+- Work around ``pyupgrade`` edge cases (#26384)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+- Document dag_file_processor_timeouts metric as deprecated (#27067)
+- Drop support for PostgreSQL 10 (#27594)
+- Update index.rst (#27529)
+- Add note about pushing the lazy XCom proxy to XCom (#27250)
+- Fix BaseOperator link (#27441)
+- [docs] best-practices add use variable with template example. (#27316)
+- docs for custom view using plugin (#27244)
+- Update graph view and grid view on overview page (#26909)
+- Documentation fixes (#26819)
+- make consistency on markup title string level (#26696)
+- Add documentation to dag test function (#26713)
+- Fix broken URL for ``docker-compose.yaml`` (#26726)
+- Add a note against use of top level code in timetable (#26649)
+- Fix example_datasets dag names (#26495)
+- Update docs: zip-like effect is now possible in task mapping (#26435)
+- changing to task decorator in docs from classic operator use (#25711)
+
+Airflow 2.4.3 (2022-11-14)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+Make ``RotatingFilehandler`` used in ``DagProcessor`` non-caching (#27223)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+ If you want to reduce cache memory usage when ``CONFIG_PROCESSOR_MANAGER_LOGGER=True`` and you created your local settings earlier,
+ you can update ``processor_manager_handler`` to use the ``airflow.utils.log.non_caching_file_handler.NonCachingRotatingFileHandler`` handler instead of ``logging.handlers.RotatingFileHandler``. (#27065)
+
+Bug Fixes
+^^^^^^^^^
+- Fix double logging with some task logging handler (#27591)
+- Replace FAB url filtering function with Airflow's (#27576)
+- Fix mini scheduler expansion of mapped task (#27506)
+- ``SLAMiss`` is nullable and not always given back when pulling task instances (#27423)
+- Fix behavior of ``_`` when searching for DAGs (#27448)
+- Fix getting the ``dag/task`` ids from BaseExecutor (#27550)
+- Fix SQLAlchemy primary key black-out error on DDRQ (#27538)
+- Fix IntegrityError during webserver startup (#27297)
+- Add case insensitive constraint to username (#27266)
+- Fix python external template keys (#27256)
+- Reduce extraneous task log requests (#27233)
+- Make ``RotatingFilehandler`` used in ``DagProcessor`` non-caching (#27223)
+- Listener: Set task on SQLAlchemy TaskInstance object (#27167)
+- Fix dags list page auto-refresh & jump search null state (#27141)
+- Set ``executor.job_id`` to ``BackfillJob.id`` for backfills (#27020)
+
+Misc/Internal
+^^^^^^^^^^^^^
+- Bump loader-utils from ``1.4.0`` to ``1.4.1`` in ``/airflow/www`` (#27552)
+- Reduce log level for k8s ``TCP_KEEPALIVE`` etc warnings (#26981)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+- Use correct executable in docker compose docs (#27529)
+- Fix wording in DAG Runs description (#27470)
+- Document that ``KubernetesExecutor`` overwrites container args (#27450)
+- Fix ``BaseOperator`` links (#27441)
+- Correct timer units to seconds from milliseconds. (#27360)
+- Add missed import in the Trigger Rules example (#27309)
+- Update SLA wording to reflect it is relative to ``Dag Run`` start. (#27111)
+- Add ``kerberos`` environment variables to the docs (#27028)
+
+Airflow 2.4.2 (2022-10-23)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+Default for ``[webserver] expose_stacktrace`` changed to ``False`` (#27059)
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+ The default for ``[webserver] expose_stacktrace`` has changed from ``True`` to ``False``. This means administrators must opt in to expose tracebacks to end users.
+
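One way to opt back in is through an environment variable, which follows the ``AIRFLOW__{SECTION}__{KEY}`` naming convention. The helper ``airflow_env_var`` is a hypothetical convenience for this sketch, and the snippet only builds an environment mapping rather than starting a webserver.

```python
import os


def airflow_env_var(section, key):
    """Build the AIRFLOW__{SECTION}__{KEY} name that overrides a config option."""
    return f"AIRFLOW__{section.upper()}__{key.upper()}"


# Copy the current environment and opt in to exposing tracebacks,
# for example on a development machine only.
env = dict(os.environ)
env[airflow_env_var("webserver", "expose_stacktrace")] = "True"
```

The resulting mapping can be passed to whatever launches the webserver process. Setting ``expose_stacktrace = True`` under ``[webserver]`` in ``airflow.cfg`` is the file-based equivalent.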
+Bug Fixes
+^^^^^^^^^
+- Make tracebacks opt-in (#27059)
+- Add missing AUTOINC/SERIAL for FAB tables (#26885)
+- Add separate error handler for 405(Method not allowed) errors (#26880)
+- Don't re-patch pods that are already controlled by current worker (#26778)
+- Handle mapped tasks in task duration chart (#26722)
+- Fix task duration cumulative chart (#26717)
+- Avoid 500 on dag redirect (#27064)
+- Filter dataset dependency data on webserver (#27046)
+- Remove double collection of dags in ``airflow dags reserialize`` (#27030)
+- Fix auto refresh for graph view (#26926)
+- Don't overwrite connection extra with invalid json (#27142)
+- Fix next run dataset modal links (#26897)
+- Change dag audit log sort by date from asc to desc (#26895)
+- Bump min version of jinja2 (#26866)
+- Add missing colors to ``state_color_mapping`` jinja global (#26822)
+- Fix running debuggers inside ``airflow tasks test`` (#26806)
+- Fix warning when using xcomarg dependencies (#26801)
+- demote Removed state in priority for displaying task summaries (#26789)
+- Ensure the log messages from operators during parsing go somewhere (#26779)
+- Add restarting state to TaskState Enum in REST API (#26776)
+- Allow retrieving error message from data.detail (#26762)
+- Simplify origin string cleaning (#27143)
+- Remove DAG parsing from StandardTaskRunner (#26750)
+- Fix non-hidden cumulative chart on duration view (#26716)
+- Remove TaskFail duplicates check (#26714)
+- Fix airflow tasks run --local when dags_folder differs from that of processor (#26509)
+- Fix yarn warning from d3-color (#27139)
+- Fix version for a couple configurations (#26491)
+- Revert "No grid auto-refresh for backfill dag runs (#25042)" (#26463)
+- Retry on Airflow Schedule DAG Run DB Deadlock (#26347)
+
+Misc/Internal
+^^^^^^^^^^^^^
+- Clean-ups around task-mapping code (#26879)
+- Move user-facing string to template (#26815)
+- add icon legend to datasets graph (#26781)
+- Bump ``sphinx`` and ``sphinx-autoapi`` (#26743)
+- Simplify ``RTIF.delete_old_records()`` (#26667)
+- Bump FAB to ``4.1.4`` (#26393)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+- Fixed triple quotes in task group example (#26829)
+- Documentation fixes (#26819)
+- make consistency on markup title string level (#26696)
+- Add a note against use of top level code in timetable (#26649)
+- Fix broken URL for ``docker-compose.yaml`` (#26726)
+
+
+Airflow 2.4.1 (2022-09-30)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+No significant changes.
+
+Bug Fixes
+^^^^^^^^^
+
+- When rendering template, unmap task in context (#26702)
+- Fix scroll overflow for ConfirmDialog (#26681)
+- Resolve deprecation warning re ``Table.exists()`` (#26616)
+- Fix XComArg zip bug (#26636)
+- Use COALESCE when ordering runs to handle NULL (#26626)
+- Check user is active (#26635)
+- No missing user warning for public admin (#26611)
+- Allow MapXComArg to resolve after serialization (#26591)
+- Resolve warning about DISTINCT ON query on dags view (#26608)
+- Log warning when secret backend kwargs is invalid (#26580)
+- Fix grid view log try numbers (#26556)
+- Template rendering issue in passing ``templates_dict`` to task decorator (#26390)
+- Fix Deferrable stuck as ``scheduled`` during backfill (#26205)
+- Suppress SQLALCHEMY_TRACK_MODIFICATIONS warning in db init (#26617)
+- Correctly set ``json_provider_class`` on Flask app so it uses our encoder (#26554)
+- Fix WSGI root app (#26549)
+- Fix deadlock when mapped task with removed upstream is rerun (#26518)
+- ExecutorConfigType should be ``cacheable`` (#26498)
+- Fix proper joining of the path for logs retrieved from celery workers (#26493)
+- DAG Deps extends ``base_template`` (#26439)
+- Don't update backfill run from the scheduler (#26342)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+
+- Clarify owner links document (#26515)
+- Fix invalid RST in dataset concepts doc (#26434)
+- Document the ``non-sensitive-only`` option for ``expose_config`` (#26507)
+- Fix ``example_datasets`` dag names (#26495)
+- Zip-like effect is now possible in task mapping (#26435)
+- Use task decorator in docs instead of classic operators (#25711)
+
+Airflow 2.4.0 (2022-09-19)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+Data-aware Scheduling and ``Dataset`` concept added to Airflow
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+New in this release of Airflow is the concept of Datasets, and with it a new way of scheduling DAGs:
+data-aware scheduling.
+
+This allows DAG runs to be automatically created as a result of a task "producing" a dataset. In some ways
+this can be thought of as the inverse of ``TriggerDagRunOperator``, where instead of the producing DAG
+controlling which DAGs get created, the consuming DAGs can "listen" for changes.
+
+A dataset is identified by a URI:
+
+.. code-block:: python
+
+ from airflow import Dataset
+
+ # The URI doesn't have to be absolute
+ dataset = Dataset(uri="my-dataset")
+ # Or you can use a scheme to show where it lives.
+ dataset2 = Dataset(uri="s3://bucket/prefix")
+
+To create a DAG that runs whenever a Dataset is updated use the new ``schedule`` parameter (see below) and
+pass a list of 1 or more Datasets:
+
+.. code-block:: python
+
+    with DAG(dag_id='dataset-consumer', schedule=[dataset]):
+        ...
+
+And to mark a task as producing a dataset pass the dataset(s) to the ``outlets`` attribute:
+
+.. code-block:: python
+
+    @task(outlets=[dataset])
+    def my_task():
+        ...
+
+
+    # Or for classic operators
+    BashOperator(task_id="update-ds", bash_command=..., outlets=[dataset])
+
+If you have the producer and consumer in different files, you do not need to use the same Dataset object:
+two ``Dataset()``\s created with the same URI are equal.
+
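The equality rule can be illustrated without Airflow installed. The ``Dataset`` class below is a stand-in with value semantics, written only for this sketch. It is not Airflow's class and says nothing about how Airflow implements the comparison.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dataset:
    """Stand-in that compares by URI; not Airflow's actual class."""

    uri: str


# Defined in the producer's DAG file.
produced = Dataset(uri="s3://bucket/prefix")
# Defined independently in the consumer's DAG file.
consumed = Dataset(uri="s3://bucket/prefix")
```

Here ``produced == consumed`` holds even though the two names refer to different objects, which is why the URI string is the only thing the two files need to agree on.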
+Datasets represent the abstract concept of a dataset, and (for now) do not have any direct read or write
+capability - in this release we are adding the foundational feature that we will build upon.
+
+For more info on Datasets please see :doc:`/authoring-and-scheduling/datasets`.
+
+Expanded dynamic task mapping support
+"""""""""""""""""""""""""""""""""""""
+
+Dynamic task mapping now includes support for ``expand_kwargs``, ``zip`` and ``map``.
+
+For more info on dynamic task mapping please see :doc:`/authoring-and-scheduling/dynamic-task-mapping`.
+
+DAGS used in a context manager no longer need to be assigned to a module variable (#23592)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+Previously you had to assign a DAG to a module-level variable in order for Airflow to pick it up. For example this
+
+
+.. code-block:: python
+
+    with DAG(dag_id="example") as dag:
+        ...
+
+
+    @dag
+    def dag_maker():
+        ...
+
+
+    dag2 = dag_maker()
+
+
+can become
+
+.. code-block:: python
+
+    with DAG(dag_id="example"):
+        ...
+
+
+    @dag
+    def dag_maker():
+        ...
+
+
+    dag_maker()
+
+If you want to disable the behaviour for any reason then set ``auto_register=False`` on the dag:
+
+.. code-block:: python
+
+    # This dag will not be picked up by Airflow as it's not assigned to a variable
+    with DAG(dag_id="example", auto_register=False):
+        ...
+
+Deprecation of ``schedule_interval`` and ``timetable`` arguments (#25410)
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+We added a new DAG argument, ``schedule``, that accepts a cron expression, a ``timedelta`` object, a *timetable* object, or a list of dataset objects. The ``schedule_interval`` and ``timetable`` arguments are deprecated.
+
+If you previously used the ``@daily`` cron preset, your DAG may have looked like this:
+
+.. code-block:: python
+
+    with DAG(
+        dag_id="my_example",
+        start_date=datetime(2021, 1, 1),
+        schedule_interval="@daily",
+    ):
+        ...
+
+Going forward, you should use the ``schedule`` argument instead:
+
+.. code-block:: python
+
+    with DAG(
+        dag_id="my_example",
+        start_date=datetime(2021, 1, 1),
+        schedule="@daily",
+    ):
+        ...
+
+The same is true if you used a custom timetable. Previously you would have used the ``timetable`` argument:
+
+.. code-block:: python
+
+    with DAG(
+        dag_id="my_example",
+        start_date=datetime(2021, 1, 1),
+        timetable=EventsTimetable(event_dates=[pendulum.datetime(2022, 4, 5)]),
+    ):
+        ...
+
+Now you should use the ``schedule`` argument:
+
+.. code-block:: python
+
+    with DAG(
+        dag_id="my_example",
+        start_date=datetime(2021, 1, 1),
+        schedule=EventsTimetable(event_dates=[pendulum.datetime(2022, 4, 5)]),
+    ):
+        ...
+
+Removal of experimental Smart Sensors (#25507)
+""""""""""""""""""""""""""""""""""""""""""""""
+
+Smart Sensors were added in 2.0 and deprecated in favor of Deferrable operators in 2.2, and have now been removed.
+
+``airflow.contrib`` packages and deprecated modules are dynamically generated (#26153, #26179, #26167)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+The ``airflow.contrib`` packages and the deprecated Airflow 1.10 modules in the ``airflow.hooks``, ``airflow.operators``, and ``airflow.sensors`` packages are now generated dynamically. Users can continue to use the deprecated contrib classes, but static code check tools can no longer see them and will report them as missing. We recommend moving to the non-deprecated classes.
+
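One common Python mechanism for generating module attributes on demand is a module-level ``__getattr__`` (PEP 562). The sketch below shows that mechanism with the standard library only. The module name ``legacy_contrib``, the ``MOVED`` mapping, and ``make_deprecated_module`` are hypothetical, and this is not a copy of Airflow's implementation.

```python
import importlib
import types
import warnings

# Hypothetical mapping: deprecated name -> "module:attribute" of its new home.
MOVED = {"OldOrderedDict": "collections:OrderedDict"}


def make_deprecated_module(name, moved):
    """Build a module whose attributes are resolved lazily via __getattr__."""
    module = types.ModuleType(name)

    def __getattr__(attr):
        if attr not in moved:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        target_module, target_attr = moved[attr].split(":")
        warnings.warn(
            f"{name}.{attr} is deprecated; use {target_module}.{target_attr}.",
            DeprecationWarning,
            stacklevel=2,
        )
        return getattr(importlib.import_module(target_module), target_attr)

    # Module attribute lookup falls back to this hook for unknown names.
    module.__getattr__ = __getattr__
    return module


legacy = make_deprecated_module("legacy_contrib", MOVED)
```

Because ``OldOrderedDict`` is never defined in the module's namespace, a tool that inspects the source statically has nothing to find, while ``legacy.OldOrderedDict`` still resolves at run time and emits a ``DeprecationWarning``.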
+``DBApiHook`` and ``SQLSensor`` have moved (#24836)
+"""""""""""""""""""""""""""""""""""""""""""""""""""
+
+``DBApiHook`` and ``SQLSensor`` have been moved to the ``apache-airflow-providers-common-sql`` provider.
+
+DAG runs sorting logic changed in grid view (#25090)
+""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+The ordering of DAG runs in the grid view has been changed to be more "natural".
+The new logic generally orders by data interval, but a custom ordering can be
+applied by setting the DAG to use a custom timetable.
+
+
+New Features
+^^^^^^^^^^^^
+- Add Data-aware Scheduling (AIP-48)
+- Add ``@task.short_circuit`` TaskFlow decorator (#25752)
+- Make ``execution_date_or_run_id`` optional in ``tasks test`` command (#26114)
+- Automatically register DAGs that are used in a context manager (#23592, #26398)
+- Add option of sending DAG parser logs to stdout. (#25754)
+- Support multiple ``DagProcessors`` parsing files from different locations. (#25935)
+- Implement ``ExternalPythonOperator`` (#25780)
+- Make execution_date optional for command ``dags test`` (#26111)
+- Implement ``expand_kwargs()`` against a literal list (#25925)
+- Add trigger rule tooltip (#26043)
+- Add conf parameter to CLI for airflow dags test (#25900)
+- Include scheduled slots in pools view (#26006)
+- Add ``output`` property to ``MappedOperator`` (#25604)
+- Add roles delete command to cli (#25854)
+- Add Airflow specific warning classes (#25799)
+- Add support for ``TaskGroup`` in ``ExternalTaskSensor`` (#24902)
+- Add ``@task.kubernetes`` taskflow decorator (#25663)
+- Add a way to import Airflow without side-effects (#25832)
+- Let timetables control generated run_ids. (#25795)
+- Allow per-timetable ordering override in grid view (#25633)
+- Grid logs for mapped instances (#25610, #25621, #25611)
+- Consolidate to one ``schedule`` param (#25410)
+- DAG regex flag in backfill command (#23870)
+- Adding support for owner links in the Dags view UI (#25280)
+- Ability to clear a specific DAG Run's task instances via REST API (#23516)
+- Possibility to document DAG with a separate markdown file (#25509)
+- Add parsing context to DAG Parsing (#25161)
+- Implement ``CronTriggerTimetable`` (#23662)
+- Add option to mask sensitive data in UI configuration page (#25346)
+- Create new databases from the ORM (#24156)
+- Implement ``XComArg.zip(*xcom_args)`` (#25176)
+- Introduce ``sla_miss`` metric (#23402)
+- Implement ``map()`` semantic (#25085)
+- Add override method to TaskGroupDecorator (#25160)
+- Implement ``expand_kwargs()`` (#24989)
+- Add parameter to turn off SQL query logging (#24570)
+- Add ``DagWarning`` model, and a check for missing pools (#23317)
+- Add Task Logs to Grid details panel (#24249)
+- Added small health check server and endpoint in scheduler (#23905)
+- Add built-in External Link for ``ExternalTaskMarker`` operator (#23964)
+- Add default task retry delay config (#23861)
+- Add clear DagRun endpoint. (#23451)
+- Add support for timezone as string in cron interval timetable (#23279)
+- Add auto-refresh to dags home page (#22900, #24770)
+
+Improvements
+^^^^^^^^^^^^
+
+- Add more weekday operator and sensor examples #26071 (#26098)
+- Add subdir parameter to dags reserialize command (#26170)
+- Update zombie message to be more descriptive (#26141)
+- Only send an ``SlaCallbackRequest`` if the DAG is scheduled (#26089)
+- Promote ``Operator.output`` more (#25617)
+- Upgrade API files to typescript (#25098)
+- Less ``hacky`` double-rendering prevention in mapped task (#25924)
+- Improve Audit log (#25856)
+- Remove mapped operator validation code (#25870)
+- More ``DAG(schedule=...)`` improvements (#25648)
+- Reduce ``operator_name`` dupe in serialized JSON (#25819)
+- Make grid view group/mapped summary UI more consistent (#25723)
+- Remove useless statement in ``task_group_to_grid`` (#25654)
+- Add optional data interval to ``CronTriggerTimetable`` (#25503)
+- Remove unused code in ``/grid`` endpoint (#25481)
+- Add and document description fields (#25370)
+- Improve Airflow logging for operator Jinja template processing (#25452)
+- Update core example DAGs to use ``@task.branch`` decorator (#25242)
+- Update DAG ``audit_log`` route (#25415)
+- Change stdout and stderr access mode to append in commands (#25253)
+- Remove ``getTasks`` from Grid view (#25359)
+- Improve taskflow type hints with ParamSpec (#25173)
+- Use tables in grid details panes (#25258)
+- Explicitly list ``@dag`` arguments (#25044)
+- More typing in ``SchedulerJob`` and ``TaskInstance`` (#24912)
+- Patch ``getfqdn`` with more resilient version (#24981)
+- Replace all ``NBSP`` characters by ``whitespaces`` (#24797)
+- Re-serialize all DAGs on ``airflow db upgrade`` (#24518)
+- Rework contract of try_adopt_task_instances method (#23188)
+- Make ``expand()`` error vague so it's not misleading (#24018)
+- Add enum validation for ``[webserver]analytics_tool`` (#24032)
+- Add ``dttm`` searchable field in audit log (#23794)
+- Allow more parameters to be piped through via ``execute_in_subprocess`` (#23286)
+- Use ``func.count`` to count rows (#23657)
+- Remove stale serialized dags (#22917)
+- AIP45 Remove dag parsing in airflow run local (#21877)
+- Add support for queued state in DagRun update endpoint. (#23481)
+- Add fields to dagrun endpoint (#23440)
+- Use ``sql_alchemy_conn`` for celery result backend when ``result_backend`` is not set (#24496)
+
+Bug Fixes
+^^^^^^^^^
+
+- Have consistent types between the ORM and the migration files (#24044, #25869)
+- Disallow any dag tags longer than 100 char (#25196)
+- Add the dag_id to ``AirflowDagCycleException`` message (#26204)
+- Properly build URL to retrieve logs independently from system (#26337)
+- For worker log servers only bind to IPV6 when dual stack is available (#26222)
+- Fix ``TaskInstance.task`` not defined before ``handle_failure`` (#26040)
+- Undo secrets backend config caching (#26223)
+- Fix faulty executor config serialization logic (#26191)
+- Show ``DAGs`` and ``Datasets`` menu links based on role permission (#26183)
+- Allow setting ``TaskGroup`` tooltip via function docstring (#26028)
+- Fix RecursionError on graph view of a DAG with many tasks (#26175)
+- Fix backfill occasional deadlocking (#26161)
+- Fix ``DagRun.start_date`` not set during backfill with ``--reset-dagruns`` True (#26135)
+- Use label instead of id for dynamic task labels in graph (#26108)
+- Don't fail DagRun when leaf ``mapped_task`` is SKIPPED (#25995)
+- Add group prefix to decorated mapped task (#26081)
+- Fix UI flash when triggering with dup logical date (#26094)
+- Make items nullable for ``TaskInstance`` related endpoints to avoid API errors (#26076)
+- Fix ``BranchDateTimeOperator`` to be ``timezone-awareness-insensitive`` (#25944)
+- Fix legacy timetable schedule interval params (#25999)
+- Fix response schema for ``list-mapped-task-instance`` (#25965)
+- Properly check the existence of missing mapped TIs (#25788)
+- Fix broken auto-refresh on grid view (#25950)
+- Use per-timetable ordering in grid UI (#25880)
+- Rewrite recursion when parsing DAG into iteration (#25898)
+- Find cross-group tasks in ``iter_mapped_dependants`` (#25793)
+- Fail task if mapping upstream fails (#25757)
+- Support ``/`` in variable get endpoint (#25774)
+- Use cfg default_wrap value for grid logs (#25731)
+- Add origin request args when triggering a run (#25729)
+- Operator name separate from class (#22834)
+- Fix incorrect data interval alignment due to assumption on input time alignment (#22658)
+- Return None if an ``XComArg`` fails to resolve (#25661)
+- Correct ``json`` arg help in ``airflow variables set`` command (#25726)
+- Added MySQL index hint to use ``ti_state`` on ``find_zombies`` query (#25725)
+- Only exclude actually expanded fields from render (#25599)
+- Grid, fix toast for ``axios`` errors (#25703)
+- Fix UI redirect (#26409)
+- Require dag_id arg for dags list-runs (#26357)
+- Check for queued states for dags auto-refresh (#25695)
+- Fix upgrade code for the ``dag_owner_attributes`` table (#25579)
+- Add map index to task logs api (#25568)
+- Ensure that zombie tasks for dags with errors get cleaned up (#25550)
+- Make extra link work in UI (#25500)
+- Sync up plugin API schema and definition (#25524)
+- First/last names can be empty (#25476)
+- Refactor DAG pages to be consistent (#25402)
+- Check ``expand_kwargs()`` input type before unmapping (#25355)
+- Filter XCOM by key when calculating map lengths (#24530)
+- Fix ``ExternalTaskSensor`` not working with dynamic task (#25215)
+- Added exception catching to send default email if template file raises any exception (#24943)
+- Bring ``MappedOperator`` members in sync with ``BaseOperator`` (#24034)
+
+
+Misc/Internal
+^^^^^^^^^^^^^
+
+- Add automatically generated ``ERD`` schema for the ``MetaData`` DB (#26217)
+- Mark serialization functions as internal (#26193)
+- Remove remaining deprecated classes and replace them with ``PEP562`` (#26167)
+- Move ``dag_edges`` and ``task_group_to_dict`` to corresponding util modules (#26212)
+- Lazily import many modules to improve import speed (#24486, #26239)
+- Fix incorrect typing information (#26077)
+- Add missing contrib classes to deprecated dictionaries (#26179)
+- Re-configure/connect the ``ORM`` after forking to run a DAG processor (#26216)
+- Remove cattrs from lineage processing. (#26134)
+- Remove deprecated contrib files and replace them with ``PEP-562`` getattr (#26153)
+- Make ``BaseSerialization.serialize`` "public" to other classes. (#26142)
+- Change the template to use human readable task_instance description (#25960)
+- Bump ``moment-timezone`` from ``0.5.34`` to ``0.5.35`` in ``/airflow/www`` (#26080)
+- Fix Flask deprecation warning (#25753)
+- Add ``CamelCase`` to generated operations types (#25887)
+- Fix migration issues and tighten the CI upgrade/downgrade test (#25869)
+- Fix type annotations in ``SkipMixin`` (#25864)
+- Workaround setuptools editable packages path issue (#25848)
+- Bump ``undici`` from ``5.8.0`` to ``5.9.1`` in /airflow/www (#25801)
+- Add custom_operator_name attr to ``_BranchPythonDecoratedOperator`` (#25783)
+- Clarify ``filename_template`` deprecation message (#25749)
+- Use ``ParamSpec`` to replace ``...`` in Callable (#25658)
+- Remove deprecated modules (#25543)
+- Documentation on task mapping additions (#24489)
+- Remove Smart Sensors (#25507)
+- Fix ``elasticsearch`` test config to avoid warning on deprecated template (#25520)
+- Bump ``terser`` from ``4.8.0`` to ``4.8.1`` in /airflow/ui (#25178)
+- Generate ``typescript`` types from rest ``API`` docs (#25123)
+- Upgrade utils files to ``typescript`` (#25089)
+- Upgrade remaining context file to ``typescript``. (#25096)
+- Migrate files to ``ts`` (#25267)
+- Upgrade grid Table component to ``ts`` (#25074)
+- Skip mapping against mapped ``ti`` if it returns None (#25047)
+- Refactor ``js`` file structure (#25003)
+- Move mapped kwargs introspection to separate type (#24971)
+- Only assert stuff for mypy when type checking (#24937)
+- Bump ``moment`` from ``2.29.3`` to ``2.29.4`` in ``/airflow/www`` (#24885)
+- Remove "bad characters" from our codebase (#24841)
+- Remove ``xcom_push`` flag from ``BashOperator`` (#24824)
+- Move Flask hook registration to end of file (#24776)
+- Upgrade more javascript files to ``typescript`` (#24715)
+- Clean up task decorator type hints and docstrings (#24667)
+- Preserve original order of providers' connection extra fields in UI (#24425)
+- Rename ``charts.css`` to ``chart.css`` (#24531)
+- Rename ``grid.css`` to ``chart.css`` (#24529)
+- Misc: create new process group by ``set_new_process_group`` utility (#24371)
+- Airflow UI fix Prototype Pollution (#24201)
+- Bump ``moto`` version (#24222)
+- Remove unused ``[github_enterprise]`` from ref docs (#24033)
+- Clean up ``f-strings`` in logging calls (#23597)
+- Add limit for ``JPype1`` (#23847)
+- Simplify json responses (#25518)
+- Add min attrs version (#26408)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+- Add url prefix setting for ``Celery`` Flower (#25986)
+- Updating deprecated configuration in examples (#26037)
+- Fix wrong link for taskflow tutorial (#26007)
+- Reorganize tutorials into a section (#25890)
+- Fix concept doc for dynamic task map (#26002)
+- Update code examples from "classic" operators to taskflow (#25845, #25657)
+- Add instructions on manually fixing ``MySQL`` Charset problems (#25938)
+- Prefer the local Quick Start in docs (#25888)
+- Fix broken link to ``Trigger Rules`` (#25840)
+- Improve docker documentation (#25735)
+- Correctly link to Dag parsing context in docs (#25722)
+- Add note on ``task_instance_mutation_hook`` usage (#25607)
+- Note that TaskFlow API automatically passes data between tasks (#25577)
+- Update DAG run to clarify when a DAG actually runs (#25290)
+- Update tutorial docs to include a definition of operators (#25012)
+- Rewrite the Airflow documentation home page (#24795)
+- Fix ``task-generated mapping`` example (#23424)
+- Add note on subtle logical date change in ``2.2.0`` (#24413)
+- Add missing import in best-practices code example (#25391)
+
+
+
+Airflow 2.3.4 (2022-08-23)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+Added new config ``[logging]log_formatter_class`` to fix timezone display for logs on UI (#24811)
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+If you are using a custom Formatter subclass in your ``[logging]logging_config_class``, please inherit from ``airflow.utils.log.timezone_aware.TimezoneAware`` instead of ``logging.Formatter``.
+For example, in your ``custom_config.py``:
+
+.. code-block:: python
+
+    import logging
+
+    from airflow.utils.log.timezone_aware import TimezoneAware
+
+
+    # before
+    class YourCustomFormatter(logging.Formatter):
+        ...
+
+
+    # after
+    class YourCustomFormatter(TimezoneAware):
+        ...
+
+
+    # LOGGING_CONFIG is the logging config dict defined in your custom_config.py
+    AIRFLOW_FORMATTER = LOGGING_CONFIG["formatters"]["airflow"]
+    AIRFLOW_FORMATTER["class"] = "somewhere.your.custom_config.YourCustomFormatter"
+    # or use the TimezoneAware class directly if you don't have a custom Formatter
+    AIRFLOW_FORMATTER["class"] = "airflow.utils.log.timezone_aware.TimezoneAware"
+
+Bug Fixes
+^^^^^^^^^
+
+- Disable ``attrs`` state management on ``MappedOperator`` (#24772)
+- Serialize ``pod_override`` to JSON before pickling ``executor_config`` (#24356)
+- Fix ``pid`` check (#24636)
+- Rotate session id during login (#25771)
+- Fix mapped sensor with reschedule mode (#25594)
+- Cache the custom secrets backend so the same instance gets re-used (#25556)
+- Add right padding (#25554)
+- Fix reducing mapped length of a mapped task at runtime after a clear (#25531)
+- Fix ``airflow db reset`` when dangling tables exist (#25441)
+- Change ``disable_verify_ssl`` behaviour (#25023)
+- Set default task group in dag.add_task method (#25000)
+- Removed interfering force of index. (#25404)
+- Remove useless logging line (#25347)
+- Adding mysql index hint to use index on ``task_instance.state`` in critical section query (#25673)
+- Configurable umask to all daemonized processes. (#25664)
+- Fix the errors raised when None is passed to template filters (#25593)
+- Allow wildcarded CORS origins (#25553)
+- Fix "This Session's transaction has been rolled back" (#25532)
+- Fix Serialization error in ``TaskCallbackRequest`` (#25471)
+- fix - resolve bash by absolute path (#25331)
+- Add ``__repr__`` to ParamsDict class (#25305)
+- Only load distribution of a name once (#25296)
+- convert ``TimeSensorAsync`` ``target_time`` to utc on call time (#25221)
+- call ``updateNodeLabels`` after ``expandGroup`` (#25217)
+- Stop SLA callbacks gazumping other callbacks and DOS'ing the ``DagProcessorManager`` queue (#25147)
+- Fix ``invalidateQueries`` call (#25097)
+- ``airflow/www/package.json``: Add name, version fields. (#25065)
+- No grid auto-refresh for backfill dag runs (#25042)
+- Fix tag link on dag detail page (#24918)
+- Fix zombie task handling with multiple schedulers (#24906)
+- Bind log server on worker to ``IPv6`` address (#24755) (#24846)
+- Add ``%z`` for ``%(asctime)s`` to fix timezone for logs on UI (#24811)
+- ``TriggerDagRunOperator.operator_extra_links`` is attr (#24676)
+- Send DAG timeout callbacks to processor outside of ``prohibit_commit`` (#24366)
+- Don't rely on current ORM structure for db clean command (#23574)
+- Clear next method when clearing TIs (#23929)
+- Two typing fixes (#25690)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+
+- Update set-up-database.rst (#24983)
+- Fix syntax in mysql setup documentation (#24893) (#24939)
+- Note how DAG policy works with default_args (#24804)
+- Update PythonVirtualenvOperator Howto (#24782)
+- Doc: Add hyperlinks to Github PRs for Release Notes (#24532)
+
+Misc/Internal
+^^^^^^^^^^^^^
+
+- Remove deprecation warning when using default remote tasks logging handlers (#25764)
+- clearer method name in scheduler_job.py (#23702)
+- Bump cattrs version (#25689)
+- Include missing mention of ``external_executor_id`` in ``sql_engine_collation_for_ids`` docs (#25197)
+- Refactor ``DR.task_instance_scheduling_decisions`` (#24774)
+- Sort operator extra links (#24992)
+- Extends ``resolve_xcom_backend`` function level documentation (#24965)
+- Upgrade FAB to 4.1.3 (#24884)
+- Limit Flask to <2.3 in the wake of 2.2 breaking our tests (#25511)
+- Limit astroid version to < 2.12 (#24982)
+- Move javascript compilation to host (#25169)
+- Bump typing-extensions and mypy for ParamSpec (#25088)
+
+
+Airflow 2.3.3 (2022-07-09)
+--------------------------
+
+Significant Changes
+^^^^^^^^^^^^^^^^^^^
+
+We've upgraded Flask App Builder to a major version 4.* (#24399)
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+Flask App Builder is one of the important components of Airflow Webserver, as
+it uses a lot of dependencies that are essential to run the webserver and integrate it
+in enterprise environments - especially authentication.
+
+FAB 4.* upgrades a number of dependencies to new major releases, which include fixes for a number
+of security issues. Extensive testing was done to bring in the dependencies in a
+backwards-compatible way. However, the dependencies themselves implement breaking changes in their
+internals, so some of those changes might affect users who use the libraries for their
+own purposes.
+
+One important change that you likely will need to apply to OAuth configuration is to add
+``server_metadata_url`` or ``jwks_uri`` and you can read about it more
+in `this issue `_.
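+
+As a rough illustration of where these keys go, the sketch below shows a single provider entry in
+``webserver_config.py``. The provider name, credentials and all URLs are hypothetical placeholders,
+not values taken from the linked issue; substitute the ones published by your identity provider.
+
+.. code-block:: python
+
+    OAUTH_PROVIDERS = [
+        {
+            "name": "example_idp",  # hypothetical provider name
+            "icon": "fa-key",
+            "token_key": "access_token",
+            "remote_app": {
+                "client_id": "YOUR_CLIENT_ID",  # placeholder
+                "client_secret": "YOUR_CLIENT_SECRET",  # placeholder
+                "api_base_url": "https://idp.example.com/",  # placeholder URL
+                "client_kwargs": {"scope": "openid email profile"},
+                # Point to the provider's OpenID Connect discovery document ...
+                "server_metadata_url": "https://idp.example.com/.well-known/openid-configuration",
+                # ... or, if no discovery document is available, set the key set URL directly:
+                # "jwks_uri": "https://idp.example.com/jwks",
+            },
+        }
+    ]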
+
+Here is the list of breaking changes in dependencies that come together with FAB 4:
+
+ * ``Flask`` from 1.X to 2.X `breaking changes `__
+
+ * ``flask-jwt-extended`` 3.X to 4.X `breaking changes: `__
+
+ * ``Jinja2`` 2.X to 3.X `breaking changes: `__
+
+ * ``Werkzeug`` 1.X to 2.X `breaking changes `__
+
+ * ``pyJWT`` 1.X to 2.X `breaking changes: `__
+
+ * ``Click`` 7.X to 8.X `breaking changes: `__
+
+ * ``itsdangerous`` 1.X to 2.X `breaking changes `__
+
+Bug Fixes
+^^^^^^^^^
+
+- Fix exception in mini task scheduler (#24865)
+- Fix cycle bug with attaching label to task group (#24847)
+- Fix timestamp defaults for ``sensorinstance`` (#24638)
+- Move fallible ``ti.task.dag`` assignment back inside ``try/except`` block (#24533) (#24592)
+- Add missing types to ``FSHook`` (#24470)
+- Mask secrets in ``stdout`` for ``airflow tasks test`` (#24362)
+- ``DebugExecutor`` use ``ti.run()`` instead of ``ti._run_raw_task`` (#24357)
+- Fix bugs in ``URI`` constructor for ``MySQL`` connection (#24320)
+- Missing ``scheduleinterval`` nullable true added in ``openapi`` (#24253)
+- Unify ``return_code`` interface for task runner (#24093)
+- Handle occasional deadlocks in trigger with retries (#24071)
+- Remove special serde logic for mapped ``op_kwargs`` (#23860)
+- ``ExternalTaskSensor`` respects ``soft_fail`` if the external task enters a ``failed_state`` (#23647)
+- Fix ``StatD`` timing metric units (#21106)
+- Add ``cache_ok`` flag to sqlalchemy TypeDecorators. (#24499)
+- Allow for ``LOGGING_LEVEL=DEBUG`` (#23360)
+- Fix grid date ticks (#24738)
+- Debounce status highlighting in Grid view (#24710)
+- Fix Grid vertical scrolling (#24684)
+- don't try to render child rows for closed groups (#24637)
+- Do not calculate grid root instances (#24528)
+- Maintain grid view selection on filtering upstream (#23779)
+- Speed up ``grid_data`` endpoint by 10x (#24284)
+- Apply per-run log templates to log handlers (#24153)
+- Don't crash scheduler if exec config has old k8s objects (#24117)
+- ``TI.log_url`` fix for ``map_index`` (#24335)
+- Fix migration ``0080_2_0_2`` - Replace null values before setting column not null (#24585)
+- Patch ``sql_alchemy_conn`` if old Postgres schemes used (#24569)
+- Seed ``log_template`` table (#24511)
+- Fix deprecated ``log_id_template`` value (#24506)
+- Fix toast messages (#24505)
+- Add indexes for CASCADE deletes for ``task_instance`` (#24488)
+- Return empty dict if Pod JSON encoding fails (#24478)
+- Improve grid rendering performance with a custom tooltip (#24417, #24449)
+- Check for ``run_id`` for grid group summaries (#24327)
+- Optimize calendar view for cron scheduled DAGs (#24262)
+- Use ``get_hostname`` instead of ``socket.getfqdn`` (#24260)
+- Check that edge nodes actually exist (#24166)
+- Fix ``useTasks`` crash on error (#24152)
+- Do not fail re-queued TIs (#23846)
+- Reduce grid view API calls (#24083)
+- Rename Permissions to Permission Pairs. (#24065)
+- Replace ``use_task_execution_date`` with ``use_task_logical_date`` (#23983)
+- Grid fix details button truncated and small UI tweaks (#23934)
+- Add TaskInstance State ``REMOVED`` to finished states and success states (#23797)
+- Fix mapped task immutability after clear (#23667)
+- Fix permission issue for dag that has dot in name (#23510)
+- Fix closing connection ``dbapi.get_pandas_df`` (#23452)
+- Check bag DAG ``schedule_interval`` match timetable (#23113)
+- Parse error for task added to multiple groups (#23071)
+- Fix flaky order of returned dag runs (#24405)
+- Migrate ``jsx`` files that affect run/task selection to ``tsx`` (#24509)
+- Fix links to sources for examples (#24386)
+- Set proper ``Content-Type`` and ``charset`` on ``grid_data`` endpoint (#24375)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+
+- Update templates doc to mention ``extras`` and format Airflow ``Vars`` / ``Conns`` (#24735)
+- Document built in Timetables (#23099)
+- Alphabetizes two tables (#23923)
+- Clarify that users should not use Maria DB (#24556)
+- Add imports to deferring code samples (#24544)
+- Add note about image regeneration in June 2022 (#24524)
+- Small cleanup of ``get_current_context()`` chapter (#24482)
+- Fix default 2.2.5 ``log_id_template`` (#24455)
+- Update description of installing providers separately from core (#24454)
+- Mention context variables and logging (#24304)
+
+Misc/Internal
+^^^^^^^^^^^^^
+
+- Remove internet explorer support (#24495)
+- Removing magic status code numbers from ``api_connexion`` (#24050)
+- Upgrade FAB to ``4.1.2`` (#24619)
+- Switch Markdown engine to ``markdown-it-py`` (#19702)
+- Update ``rich`` to latest version across the board. (#24186)
+- Get rid of ``TimedJSONWebSignatureSerializer`` (#24519)
+- Update flask-appbuilder ``authlib``/ ``oauth`` dependency (#24516)
+- Upgrade to ``webpack`` 5 (#24485)
+- Add ``typescript`` (#24337)
+- The JWT claims in the request to retrieve logs have been standardized: we use ``nbf`` and ``aud`` claims for
+ maturity and audience of the requests. Also "filename" payload field is used to keep log name. (#24519)
+- Address all ``yarn`` test warnings (#24722)
+- Upgrade to react 18 and chakra 2 (#24430)
+- Refactor ``DagRun.verify_integrity`` (#24114)
+- Upgrade FAB to ``4.1.1`` (#24399)
+- We now need at least ``Flask-WTF 0.15`` (#24621)
+
+
+Airflow 2.3.2 (2022-06-04)
+--------------------------
+
+No significant changes
+
+Bug Fixes
+^^^^^^^^^
+
+- Run the ``check_migration`` loop at least once
+- Fix grid view for mapped tasks (#24059)
+- Icons in grid view for different DAG run types (#23970)
+- Faster grid view (#23951)
+- Disallow calling expand with no arguments (#23463)
+- Add missing ``is_mapped`` field to Task response. (#23319)
+- DagFileProcessorManager: Start a new process group only if current process not a session leader (#23872)
+- Mask sensitive values for not-yet-running TIs (#23807)
+- Add cascade to ``dag_tag`` to ``dag`` foreign key (#23444)
+- Use ``--subdir`` argument value for standalone dag processor. (#23864)
+- Highlight task states by hovering on legend row (#23678)
+- Fix and speed up grid view (#23947)
+- Prevent UI from crashing if grid task instances are null (#23939)
+- Remove redundant register exit signals in ``dag-processor`` command (#23886)
+- Add ``__wrapped__`` property to ``_TaskDecorator`` (#23830)
+- Fix UnboundLocalError when ``sql`` is empty list in DbApiHook (#23816)
+- Enable clicking on DAG owner in autocomplete dropdown (#23804)
+- Simplify flash message for ``_airflow_moved`` tables (#23635)
+- Exclude missing tasks from the gantt view (#23627)
+
+Doc only changes
+^^^^^^^^^^^^^^^^
+
+- Add column names for DB Migration Reference (#23853)
+
+Misc/Internal
+^^^^^^^^^^^^^
+
+- Remove pinning for xmltodict (#23992)
+
+
Airflow 2.3.1 (2022-05-25)
--------------------------
@@ -118,7 +1258,7 @@ Continuing the effort to bind TaskInstance to a DagRun, XCom entries are now als
Task log templates are now read from the metadata database instead of ``airflow.cfg`` (#20165)
""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
-Previously, a task’s log is dynamically rendered from the ``[core] log_filename_template`` and ``[elasticsearch] log_id_template`` config values at runtime. This resulted in unfortunate characteristics, e.g. it is impractical to modify the config value after an Airflow instance is running for a while, since all existing task logs have be saved under the previous format and cannot be found with the new config value.
+Previously, a task's log was dynamically rendered from the ``[core] log_filename_template`` and ``[elasticsearch] log_id_template`` config values at runtime. This resulted in unfortunate characteristics, e.g. it is impractical to modify the config value after an Airflow instance has been running for a while, since all existing task logs have been saved under the previous format and cannot be found with the new config value.
A new ``log_template`` table is introduced to solve this problem. This table is synchronized with the aforementioned config values every time Airflow starts, and a new field ``log_template_id`` is added to every DAG run to point to the format used by tasks (``NULL`` indicates the first ever entry for compatibility).
@@ -135,9 +1275,9 @@ No change in behavior is expected. This was necessary in order to take advantag
XCom now defined by ``run_id`` instead of ``execution_date`` (#20975)
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
-As a continuation to the TaskInstance-DagRun relation change started in Airflow 2.2, the ``execution_date`` columns on XCom has been removed from the database, and replaced by an `association proxy `_ field at the ORM level. If you access Airflow’s metadata database directly, you should rewrite the implementation to use the ``run_id`` column instead.
+As a continuation to the TaskInstance-DagRun relation change started in Airflow 2.2, the ``execution_date`` column on XCom has been removed from the database, and replaced by an `association proxy `_ field at the ORM level. If you access Airflow's metadata database directly, you should rewrite the implementation to use the ``run_id`` column instead.
-Note that Airflow’s metadatabase definition on both the database and ORM levels are considered implementation detail without strict backward compatibility guarantees.
+Note that Airflow's metadatabase definition on both the database and ORM levels is considered an implementation detail without strict backward compatibility guarantees.
Non-JSON-serializable params deprecated (#21135).
"""""""""""""""""""""""""""""""""""""""""""""""""
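@@ -178,14 +1318,14 @@ Details in the `SQLAlchemy Changelog `_
-As a part of the TaskInstance-DagRun relation change, the ``execution_date`` columns on TaskInstance and TaskReschedule have been removed from the database, and replaced by `association proxy `_ fields at the ORM level. If you access Airflow’s metadatabase directly, you should rewrite the implementation to use the ``run_id`` columns instead.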
+As a part of the TaskInstance-DagRun relation change, the ``execution_date`` columns on TaskInstance and TaskReschedule have been removed from the database, and replaced by `association proxy `_ fields at the ORM level. If you access Airflow's metadatabase directly, you should rewrite the implementation to use the ``run_id`` columns instead.
-Note that Airflow’s metadatabase definition on both the database and ORM levels are considered implementation detail without strict backward compatibility guarantees.
+Note that Airflow's metadatabase definition on both the database and ORM levels is considered an implementation detail without strict backward compatibility guarantees.
DaskExecutor - Dask Worker Resources and queues
"""""""""""""""""""""""""""""""""""""""""""""""
If dask workers are not started with complementary resources to match the specified queues, it will now result in an ``AirflowException``\ , whereas before it would have just ignored the ``queue`` argument.
+Logical date of a DAG run triggered from the web UI now has its sub-second component set to zero
+"""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+Due to a change in how the logical date (``execution_date``) is generated for a manual DAG run, a manual DAG run's logical date may not match its time-of-trigger, but instead has its sub-second part zeroed out. For example, a DAG run triggered on ``2021-10-11T12:34:56.78901`` would have its logical date set to ``2021-10-11T12:34:56.00000``.
+
+This may affect some logic that relies on this quirk to detect whether a run is triggered manually or not. Note that ``dag_run.run_type`` is a more authoritative value for this purpose. Also, if you need this distinction between automated and manually-triggered runs for "next execution date" calculation, please also consider using the new data interval variables instead, which provide more consistent behavior between the two run types.
+
New Features
^^^^^^^^^^^^
@@ -6828,9 +7975,7 @@ New signature:
.. code-block:: python
- def wait_for_transfer_job(
- self, job, expected_statuses=(GcpTransferOperationStatus.SUCCESS,)
- ):
+ def wait_for_transfer_job(self, job, expected_statuses=(GcpTransferOperationStatus.SUCCESS,)):
...
The behavior of ``wait_for_transfer_job`` has changed:
@@ -7798,7 +8943,6 @@ There are five roles created for Airflow by default: Admin, User, Op, Viewer, an
Breaking changes
~~~~~~~~~~~~~~~~
-
* AWS Batch Operator renamed property queue to job_queue to prevent conflict with the internal queue from CeleryExecutor - AIRFLOW-2542
* Users created and stored in the old users table will not be migrated automatically. FAB's built-in authentication support must be reconfigured.
* Airflow dag home page is now ``/home`` (instead of ``/admin``\ ).
@@ -8711,14 +9855,14 @@ A logger is the entry point into the logging system. Each logger is a named buck
Each message that is written to the logger is a Log Record. Each log record contains a log level indicating the severity of that specific message. A log record can also contain useful metadata that describes the event that is being logged. This can include details such as a stack trace or an error code.
-When a message is given to the logger, the log level of the message is compared to the log level of the logger. If the log level of the message meets or exceeds the log level of the logger itself, the message will undergo further processing. If it doesn’t, the message will be ignored.
+When a message is given to the logger, the log level of the message is compared to the log level of the logger. If the log level of the message meets or exceeds the log level of the logger itself, the message will undergo further processing. If it doesn't, the message will be ignored.
Once a logger has determined that a message needs to be processed, it is passed to a Handler. This configuration is now more flexible and can easily be maintained in a single file.
Changes in Airflow Logging
~~~~~~~~~~~~~~~~~~~~~~~~~~
-Airflow's logging mechanism has been refactored to use Python’s built-in ``logging`` module to perform logging of the application. By extending classes with the existing ``LoggingMixin``\ , all the logging will go through a central logger. Also the ``BaseHook`` and ``BaseOperator`` already extend this class, so it is easily available to do logging.
+Airflow's logging mechanism has been refactored to use Python's built-in ``logging`` module to perform logging of the application. By extending classes with the existing ``LoggingMixin``\ , all the logging will go through a central logger. Also the ``BaseHook`` and ``BaseOperator`` already extend this class, so it is easily available to do logging.
The main benefit is easier configuration of the logging by setting a single centralized python file. Disclaimer; there is still some inline configuration, but this will be removed eventually. The new logging class is defined by setting the dotted classpath in your ``~/airflow/airflow.cfg`` file:
diff --git a/SELECTIVE_CHECKS.md b/SELECTIVE_CHECKS.md
deleted file mode 100644
index 3a92d9c817987..0000000000000
--- a/SELECTIVE_CHECKS.md
+++ /dev/null
@@ -1,144 +0,0 @@
-
-
-# Selective CI Checks
-
-In order to optimise our CI jobs, we've implemented optimisations to only run selected checks for some
-kind of changes. The logic implemented reflects the internal architecture of Airflow 2.0 packages
-and it helps to keep down both the usage of jobs in GitHub Actions as well as CI feedback time to
-contributors in case of simpler changes.
-
-We have the following test types (separated by packages in which they are):
-
-* Always - those are tests that should be always executed (always folder)
-* Core - for the core Airflow functionality (core folder)
-* API - Tests for the Airflow API (api and api_connexion folders)
-* CLI - Tests for the Airflow CLI (cli folder)
-* WWW - Tests for the Airflow webserver (www folder)
-* Providers - Tests for all Providers of Airflow (providers folder)
-* Other - all other tests (all other folders that are not part of any of the above)
-
-We also have several special kinds of tests that are not separated by packages but they are marked with
-pytest markers. They can be found in any of those packages and they can be selected by the appropriate
-pytest custom command line options. See `TESTING.rst `_ for details but those are:
-
-* Integration - tests that require external integration images running in docker-compose
-* Quarantined - tests that are flaky and need to be fixed
-* Postgres - tests that require Postgres database. They are only run when backend is Postgres
-* MySQL - tests that require MySQL database. They are only run when backend is MySQL
-
-Even though the types are separated, when they share the same backend version and Python version they are
-run sequentially in the same job, on the same CI machine. Each of them runs in a separate `docker run` command,
-with additional Docker cleanup between the steps, so that we neither exceed resource limits in one big
-test run nor increase the number of jobs per Pull Request.
-
-The logic implemented for the changes works as follows:
-
-1) In case of direct push (so when PR gets merged) or scheduled run, we always run all tests and checks.
- This is in order to make sure that the merge did not miss anything important. The remainder of the logic
- is executed only in case of Pull Requests. We do not add providers tests in case DEFAULT_BRANCH is
- different than main, because providers are only important in main branch and PRs to main branch.
-
-2) We retrieve which files have changed in the incoming Merge Commit (github.sha is a merge commit
- automatically prepared by GitHub in case of Pull Request, so we can retrieve the list of changed
- files from that commit directly).
-
-3) If any of the important, environment files changed (Dockerfile, ci scripts, setup.py, GitHub workflow
- files), then we again run all tests and checks. Those are cases where the logic of the checks changed
- or the environment for the checks changed so we want to make sure to check everything. We do not add
- providers tests in case DEFAULT_BRANCH is different than main, because providers are only
- important in main branch and PRs to main branch.
-
-4) If any of py files changed: we need to have CI image and run full static checks so we enable image building
-
-5) If any of docs changed: we need to have CI image so we enable image building
-
-6) If any of chart files changed, we need to run helm tests so we enable helm unit tests
-
-7) If any of API files changed, we need to run API tests so we enable them
-
-8) We check whether any of the relevant source files that trigger the tests have changed at all. Those are
- airflow sources, chart, tests and kubernetes_tests. If any of those files changed, we enable tests and we
- enable image building, because the CI images are needed to run tests.
-
-9) Then we determine which types of the tests should be run. We count all the changed files in the
- relevant airflow sources (airflow, chart, tests, kubernetes_tests) first and then we count how many
- files changed in different packages:
-
- * in any case tests in the `Always` folder are run. Those are special tests that should be run any time
- a modification to any Python code occurs. An example test of this type verifies the proper structure of
- the project, including proper naming of all files.
- * if any of the Airflow API files changed we enable `API` test type
- * if any of the Airflow CLI files changed we enable `CLI` test type and Kubernetes tests (the
- K8S tests depend on CLI changes as helm chart uses CLI to run Airflow).
- * if this is a main branch and if any of the Provider files changed we enable `Providers` test type
- * if any of the WWW files changed we enable `WWW` test type
- * if any of the Kubernetes files changed we enable `Kubernetes` test type
- * Then we subtract the count of all the per-type `specific` changed files above from the count of
- all changed files. If any files remain, we assume that some unknown files
- changed (likely from the core of airflow) and in this case we enable all test types above and the
- Core test types, simply because we do not want to risk missing anything.
- * In all cases where tests are enabled we also add Integration and, depending on
- the backend used, Postgres or MySQL types of tests.
-
-10) Quarantined tests are always run when tests are run - we need to run them often to observe how
- often they fail so that we can decide to move them out of quarantine. Details about the
- Quarantined tests are described in `TESTING.rst `_
-
-11) There is a special case of static checks. In case the above logic determines that the CI image
- needs to be built, we run the longer and more comprehensive version of static checks, including
- Mypy and Flake8. Those checks are run on all files, no matter how many files changed.
- In case the image is not built, we run only a simpler set of checks: the longer static checks
- that require the CI image are skipped, and we only run the checks on the files that changed in the incoming
- commit. Unlike flake8/mypy, those static checks are per-file based and they should not miss any
- important change.
-
-Similarly to selective tests we also run selective security scans. In Pull requests,
-the Python scan will only run when there is a python code change and JavaScript scan will only run if
-there is a JavaScript or `yarn.lock` file change. For main builds, all scans are always executed.
-
-The selective check algorithm is shown here:
-
-
-```mermaid
-flowchart TD
-A(PR arrives)-->B[Selective Check]
-B-->C{Direct push merge?}
-C-->|Yes| N[Enable images]
-N-->D(Run Full Test<br />+Quarantined<br />Run full static checks)
-C-->|No| E[Retrieve changed files]
-E-->F{Environment files changed?}
-F-->|Yes| N
-F-->|No| G{Docs changed}
-G-->|Yes| O[Enable images building]
-O-->I{Chart files changed?}
-G-->|No| I
-I-->|Yes| P[Enable helm tests]
-P-->J{API files changed}
-I-->|No| J
-J-->|Yes| Q[Enable API tests]
-Q-->H{Sources changed?}
-J-->|No| H
-H-->|Yes| R[Enable Pytests]
-R-->K[Determine test type]
-K-->S{Core files changed}
-S-->|Yes| N
-S-->|No| M(Run selected test+<br />Integration, Quarantined<br />Full static checks)
-H-->|No| L[Skip running test<br />Run subset of static checks]
-```
diff --git a/STATIC_CODE_CHECKS.rst b/STATIC_CODE_CHECKS.rst
index dccade808e425..b27d170d1c01e 100644
--- a/STATIC_CODE_CHECKS.rst
+++ b/STATIC_CODE_CHECKS.rst
@@ -59,7 +59,7 @@ After installation, pre-commit hooks are run automatically when you commit the c
only run on the files that you change during your commit, so they are usually pretty fast and do
not slow down your iteration speed on your changes. There are also ways to disable the ``pre-commits``
temporarily when you commit your code with ``--no-verify`` switch or skip certain checks that you find
-to much disturbing your local workflow. See `Available pre-commit checks<#available-pre-commit-checks>`_
+too disruptive to your local workflow. See `Available pre-commit checks <#available-pre-commit-checks>`_
and `Using pre-commit <#using-pre-commit>`_
.. note:: Additional prerequisites might be needed
@@ -76,7 +76,7 @@ The current list of prerequisites is limited to ``xmllint``:
Some pre-commit hooks also require the Docker Engine to be configured as the static
checks are executed in the Docker environment (See table in the
-Available pre-commit checks<#available-pre-commit-checks>`_ . You should build the images
+`Available pre-commit checks <#available-pre-commit-checks>`_). You should build the images
locally before installing pre-commit checks as described in `BREEZE.rst `__.
Sometimes your image is outdated and needs to be rebuilt because some dependencies have been changed.
@@ -129,195 +129,210 @@ require Breeze Docker image to be build locally.
.. BEGIN AUTO-GENERATED STATIC CHECK LIST
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| ID | Description | Image |
-+========================================================+==================================================================+=========+
-| black | Run Black (the uncompromising Python code formatter) | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| blacken-docs | Run black on python code blocks in documentation files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-airflow-2-1-compatibility | Check that providers are 2.1 compatible. | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-airflow-config-yaml-consistent | Checks for consistency between config.yml and default_config.cfg | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-airflow-providers-have-extras | Checks providers available when declared by extras in setup.py | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-apache-license-rat | Check if licenses are OK for Apache | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-base-operator-partial-arguments | Check BaseOperator and partial() arguments | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-base-operator-usage | * Check BaseOperator[Link] core imports | |
-| | * Check BaseOperator[Link] other imports | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-boring-cyborg-configuration | Checks for Boring Cyborg configuration consistency | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-breeze-top-dependencies-limited | Breeze should have small number of top-level dependencies | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-builtin-literals | Require literal syntax when initializing Python builtin types | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-changelog-has-no-duplicates | Check changelogs for duplicate entries | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-docstring-param-types | Check that docstrings do not specify param types | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-executables-have-shebangs | Check that executables have shebang | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-extra-packages-references | Checks setup extra packages | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-extras-order | Check order of extras in Dockerfile | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-for-inclusive-language | Check for language that we do not accept as community | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-hooks-apply | Check if all hooks apply to the repository | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-incorrect-use-of-LoggingMixin | Make sure LoggingMixin is not used alone | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-integrations-are-consistent | Check if integration list is consistent in various places | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-merge-conflict | Check that merge conflicts are not being committed | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-newsfragments-are-valid | Check newsfragments are valid | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-no-providers-in-core-examples | No providers imports in core example DAGs | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-no-relative-imports | No relative imports | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-persist-credentials-disabled-in-github-workflows | Check that workflow files have persist-credentials disabled | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-pre-commit-information-consistent | Update information re pre-commit hooks and verify ids and names | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-provide-create-sessions-imports | Check provide_session and create_session imports | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-provider-yaml-valid | Validate providers.yaml files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-providers-init-file-missing | Provider init file is missing | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-providers-subpackages-init-file-exist | Provider subpackage init files are there | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-pydevd-left-in-code | Check for pydevd debug statements accidentally left | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-revision-heads-map | Check that the REVISION_HEADS_MAP is up-to-date | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-safe-filter-usage-in-html | Don't use safe in templates | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-setup-order | Check order of dependencies in setup.cfg and setup.py | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-start-date-not-used-in-defaults | 'start_date' not to be defined in default_args in example_dags | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-system-tests-present | Check if system tests have required segments of code | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| check-xml | Check XML files with xmllint | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| codespell | Run codespell to check for common misspellings in files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| debug-statements | Detect accidentally committed debug statements | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| detect-private-key | Detect if private key is added to the repository | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| doctoc | Add TOC for md and rst files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| end-of-file-fixer | Make sure that there is an empty line at the end | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| fix-encoding-pragma | Remove encoding header from python files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| flynt | Run flynt string format converter for Python | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| forbid-tabs | Fail if tabs are used in the project | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| identity | Print input to the static check hooks for troubleshooting | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| insert-license | * Add license for all SQL files | |
-| | * Add license for all rst files | |
-| | * Add license for all CSS/JS/PUML/TS/TSX files | |
-| | * Add license for all JINJA template files | |
-| | * Add license for all shell files | |
-| | * Add license for all Python files | |
-| | * Add license for all XML files | |
-| | * Add license for all YAML files | |
-| | * Add license for all md files | |
-| | * Add license for all other files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| isort | Run isort to sort imports in Python files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-chart-schema | Lint chart/values.schema.json file | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-css | stylelint | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-dockerfile | Lint dockerfile | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-helm-chart | Lint Helm Chart | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-javascript | * ESLint against airflow/ui | * |
-| | * ESLint against current UI JavaScript files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-json-schema | * Lint JSON Schema files with JSON Schema | |
-| | * Lint NodePort Service with JSON Schema | |
-| | * Lint Docker compose files with JSON Schema | |
-| | * Lint chart/values.schema.json file with JSON Schema | |
-| | * Lint chart/values.yaml file with JSON Schema | |
-| | * Lint airflow/config_templates/config.yml file with JSON Schema | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-markdown | Run markdownlint | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| lint-openapi | * Lint OpenAPI using spectral | |
-| | * Lint OpenAPI using openapi-spec-validator | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| mixed-line-ending | Detect if mixed line ending is used (\r vs. \r\n) | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| pretty-format-json | Format json files | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| pydocstyle | Run pydocstyle | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| python-no-log-warn | Check if there are no deprecate log warn | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| pyupgrade | Upgrade Python code automatically | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| rst-backticks | Check if RST files use double backticks for code | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| run-flake8 | Run flake8 | * |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| run-mypy | * Run mypy for dev | * |
-| | * Run mypy for core | |
-| | * Run mypy for providers | |
-| | * Run mypy for /docs/ folder | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| run-shellcheck | Check Shell scripts syntax correctness | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| static-check-autoflake | Remove all unused code | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| trailing-whitespace | Remove trailing whitespace at end of line | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-breeze-file | Update output of breeze commands in BREEZE.rst | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-breeze-readme-config-hash | Update Breeze README.md with config files hash | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-extras | Update extras in documentation | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-in-the-wild-to-be-sorted | Sort INTHEWILD.md alphabetically | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-inlined-dockerfile-scripts | Inline Dockerfile and Dockerfile.ci scripts | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-local-yml-file | Update mounts in the local yml file | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-migration-references | Update migration ref doc | * |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-providers-dependencies | Update cross-dependencies for providers packages | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-setup-cfg-file | Update setup.cfg file with all licenses | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-spelling-wordlist-to-be-sorted | Sort alphabetically and uniquify spelling_wordlist.txt | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-supported-versions | Updates supported versions in documentation | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-vendored-in-k8s-json-schema | Vendor k8s definitions into values.schema.json | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| update-version | Update version to the latest version in the documentation | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| yamllint | Check YAML files with yamllint | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
-| yesqa | Remove unnecessary noqa statements | |
-+--------------------------------------------------------+------------------------------------------------------------------+---------+
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| ID | Description | Image |
++===========================================================+==================================================================+=========+
+| black | Run black (python formatter) | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| blacken-docs | Run black on python code blocks in documentation files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-airflow-config-yaml-consistent | Checks for consistency between config.yml and default_config.cfg | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-airflow-provider-compatibility | Check compatibility of Providers with Airflow | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-apache-license-rat | Check if licenses are OK for Apache | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-base-operator-partial-arguments | Check BaseOperator and partial() arguments | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-base-operator-usage | * Check BaseOperator[Link] core imports | |
+| | * Check BaseOperator[Link] other imports | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-boring-cyborg-configuration | Checks for Boring Cyborg configuration consistency | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-breeze-top-dependencies-limited | Breeze should have small number of top-level dependencies | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-builtin-literals | Require literal syntax when initializing Python builtin types | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-changelog-has-no-duplicates | Check changelogs for duplicate entries | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-core-deprecation-classes | Verify using of dedicated Airflow deprecation classes in core | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-daysago-import-from-utils | Make sure days_ago is imported from airflow.utils.dates | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-decorated-operator-implements-custom-name | Check @task decorator implements custom_operator_name | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-docstring-param-types | Check that docstrings do not specify param types | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-example-dags-urls | Check that example dags url include provider versions | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-executables-have-shebangs | Check that executables have shebang | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-extra-packages-references | Checks setup extra packages | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-extras-order | Check order of extras in Dockerfile | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-for-inclusive-language | Check for language that we do not accept as community | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-hooks-apply | Check if all hooks apply to the repository | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-incorrect-use-of-LoggingMixin | Make sure LoggingMixin is not used alone | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-init-decorator-arguments | Check model __init__ and decorator arguments are in sync | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-lazy-logging | Check that all logging methods are lazy | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-links-to-example-dags-do-not-use-hardcoded-versions | Check that example dags do not use hard-coded version numbers | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-merge-conflict | Check that merge conflicts are not being committed | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-newsfragments-are-valid | Check newsfragments are valid | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-no-providers-in-core-examples | No providers imports in core example DAGs | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-no-relative-imports | No relative imports | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-persist-credentials-disabled-in-github-workflows | Check that workflow files have persist-credentials disabled | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-pre-commit-information-consistent | Update information re pre-commit hooks and verify ids and names | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-provide-create-sessions-imports | Check provide_session and create_session imports | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-provider-yaml-valid | Validate provider.yaml files | * |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-providers-init-file-missing | Provider init file is missing | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-providers-subpackages-init-file-exist | Provider subpackage init files are there | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-pydevd-left-in-code | Check for pydevd debug statements accidentally left | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-revision-heads-map | Check that the REVISION_HEADS_MAP is up-to-date | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-safe-filter-usage-in-html | Don't use safe in templates | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-setup-order | Check order of dependencies in setup.cfg and setup.py | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-start-date-not-used-in-defaults | 'start_date' not to be defined in default_args in example_dags | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-system-tests-present | Check if system tests have required segments of code | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-system-tests-tocs | Check that system tests is properly added | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| check-xml | Check XML files with xmllint | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| codespell | Run codespell to check for common misspellings in files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| compile-www-assets | Compile www assets | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| compile-www-assets-dev | Compile www assets in dev mode | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| create-missing-init-py-files-tests | Create missing init.py files in tests | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| debug-statements | Detect accidentally committed debug statements | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| detect-private-key | Detect if private key is added to the repository | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| doctoc | Add TOC for md and rst files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| end-of-file-fixer | Make sure that there is an empty line at the end | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| fix-encoding-pragma | Remove encoding header from python files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| flynt | Run flynt string format converter for Python | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| identity | Print input to the static check hooks for troubleshooting | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| insert-license | * Add license for all SQL files | |
+| | * Add license for all rst files | |
+| | * Add license for all CSS/JS/PUML/TS/TSX files | |
+| | * Add license for all JINJA template files | |
+| | * Add license for all shell files | |
+| | * Add license for all Python files | |
+| | * Add license for all XML files | |
+| | * Add license for all YAML files | |
+| | * Add license for all md files | |
+| | * Add license for all other files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| isort | Run isort to sort imports in Python files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-chart-schema | Lint chart/values.schema.json file | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-css | stylelint | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-dockerfile | Lint dockerfile | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-helm-chart | Lint Helm Chart | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-json-schema | * Lint JSON Schema files with JSON Schema | |
+| | * Lint NodePort Service with JSON Schema | |
+| | * Lint Docker compose files with JSON Schema | |
+| | * Lint chart/values.schema.json file with JSON Schema | |
+| | * Lint chart/values.yaml file with JSON Schema | |
+| | * Lint airflow/config_templates/config.yml file with JSON Schema | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-markdown | Run markdownlint | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| lint-openapi | * Lint OpenAPI using spectral | |
+| | * Lint OpenAPI using openapi-spec-validator | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| mixed-line-ending | Detect if mixed line ending is used (\r vs. \r\n) | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| pretty-format-json | Format json files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| pydocstyle | Run pydocstyle | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| python-no-log-warn | Check if there are no deprecate log warn | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| pyupgrade | Upgrade Python code automatically | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| replace-bad-characters | Replace bad characters | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| rst-backticks | Check if RST files use double backticks for code | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| run-flake8 | Run flake8 | * |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| run-mypy | * Run mypy for dev | * |
+| | * Run mypy for core | |
+| | * Run mypy for providers | |
+| | * Run mypy for /docs/ folder | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| run-shellcheck | Check Shell scripts syntax correctness | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| static-check-autoflake | Remove all unused code | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| trailing-whitespace | Remove trailing whitespace at end of line | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| ts-compile-and-lint-javascript | TS types generation and ESLint against current UI files | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-breeze-cmd-output | Update output of breeze commands in BREEZE.rst | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-breeze-readme-config-hash | Update Breeze README.md with config files hash | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-er-diagram | Update ER diagram | * |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-extras | Update extras in documentation | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-in-the-wild-to-be-sorted | Sort INTHEWILD.md alphabetically | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-inlined-dockerfile-scripts | Inline Dockerfile and Dockerfile.ci scripts | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-local-yml-file | Update mounts in the local yml file | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-migration-references | Update migration ref doc | * |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-providers-dependencies | Update cross-dependencies for providers packages | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-spelling-wordlist-to-be-sorted | Sort alphabetically and uniquify spelling_wordlist.txt | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-supported-versions | Updates supported versions in documentation | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-vendored-in-k8s-json-schema | Vendor k8s definitions into values.schema.json | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| update-version | Update version to the latest version in the documentation | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| yamllint | Check YAML files with yamllint | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
+| yesqa | Remove unnecessary noqa statements | |
++-----------------------------------------------------------+------------------------------------------------------------------+---------+
.. END AUTO-GENERATED STATIC CHECK LIST
@@ -387,7 +402,7 @@ The static code checks can be launched using the Breeze environment.
You run the static code checks via ``breeze static-check`` or commands.
You can see the list of available static checks either via ``--help`` flag or by using the autocomplete
-option. Note that the ``all`` static check runs all configured static checks.
+option.
Run the ``mypy`` check for the currently staged changes:
@@ -417,19 +432,19 @@ Run all checks for the currently staged files:
.. code-block:: bash
- breeze static-checks --type all
+ breeze static-checks
Run all checks for all files:
.. code-block:: bash
- breeze static-checks --type all --all-files
+ breeze static-checks --all-files
Run all checks for last commit :
.. code-block:: bash
- breeze static-checks --type all --last-commit
+ breeze static-checks --last-commit
Debugging pre-commit check scripts requiring image
--------------------------------------------------
diff --git a/TESTING.rst b/TESTING.rst
index 12983726e1ebb..7cbb9d49c9057 100644
--- a/TESTING.rst
+++ b/TESTING.rst
@@ -51,10 +51,48 @@ Follow the guidelines when writing unit tests:
* For new tests, use standard "asserts" of Python and ``pytest`` decorators/context managers for testing
rather than ``unittest`` ones. See `pytest docs `_ for details.
* Use a parameterized framework for tests that have variations in parameters.
+* Use ``pytest.warns`` as a context manager to capture warnings rather than the ``recwarn`` fixture. We are aiming
+ for zero warnings in our tests, so we run Pytest with ``--disable-warnings`` and instead use the
+ ``pytest-capture-warnings`` plugin, which overrides the ``recwarn`` fixture behaviour.
**NOTE:** We plan to convert all unit tests to standard "asserts" semi-automatically, but this will be done later
in Airflow 2.0 development phase. That will include setUp/tearDown/context managers and decorators.
+Airflow test types
+------------------
+
+Airflow tests in the CI environment are split into several test types:
+
+* Always - those are tests that should be always executed (always folder)
+* Core - for the core Airflow functionality (core folder)
+* API - Tests for the Airflow API (api and api_connexion folders)
+* CLI - Tests for the Airflow CLI (cli folder)
+* WWW - Tests for the Airflow webserver (www folder)
+* Providers - Tests for all Providers of Airflow (providers folder)
+* Other - all other tests (all other folders that are not part of any of the above)
+
+This is done for two reasons:
+
+1. in order to selectively run only subset of the test types for some PRs
+2. in order to allow parallel execution of the tests on Self-Hosted runners
+
+For case 2, we can utilise memory and CPUs available on both CI and local development machines to run
+tests in parallel. This way we can decrease the time of running all tests in self-hosted runners from
+60 minutes to ~15 minutes.
+
+.. note::
+
+ We need to split tests manually into separate suites rather than utilise
+ ``pytest-xdist`` or ``pytest-parallel`` which could be a simpler and much more "native" parallelization
+ mechanism. Unfortunately, we cannot utilise those tools because our tests are not truly ``unit`` tests that
+ can run in parallel. A lot of our tests rely on shared databases - and they update/reset/cleanup the
+ databases while they are executing. They are also exercising features of the Database such as locking which
+ further increases cross-dependency between tests. Until we make all our tests truly unit tests (and not
+ touching the database) or until we isolate all such tests to a separate test type, we cannot really rely on
+ frameworks that run tests in parallel. In our solution each of the test types is run in parallel with its
+ own database (!) so when we have 8 test types running in parallel, there are in fact 8 databases run
+ behind the scenes to support them and each of the test types executes its own tests sequentially.
+
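The folder-based split listed above can be illustrated with a small sketch. This is not the code CI uses; it only assumes that test paths look like ``tests/<folder>/...`` (as in the examples elsewhere in this document), and the helper name ``classify_test_path`` is hypothetical.

```python
from pathlib import PurePosixPath

# Folder-to-type mapping copied from the list above; any other folder is "Other".
FOLDER_TO_TEST_TYPE = {
    "always": "Always",
    "core": "Core",
    "api": "API",
    "api_connexion": "API",
    "cli": "CLI",
    "www": "WWW",
    "providers": "Providers",
}


def classify_test_path(path: str) -> str:
    """Return the test type for a path such as tests/core/test_core.py."""
    parts = PurePosixPath(path).parts
    # Assumption: every test lives under tests/<folder>/...; the folder decides the type.
    if len(parts) < 2 or parts[0] != "tests":
        raise ValueError(f"not a path under tests/: {path}")
    return FOLDER_TO_TEST_TYPE.get(parts[1], "Other")
```

For example, ``classify_test_path("tests/providers/http/hooks/test_http.py")`` falls into the ``Providers`` type, while a path in a folder that is not listed falls into ``Other``.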
Running Unit Tests from PyCharm IDE
-----------------------------------
@@ -108,8 +146,9 @@ To run unit tests from the Visual Studio Code:
:align: center
:alt: Running tests
-Running Unit Tests
---------------------------------
+Running Unit Tests in local virtualenv
+--------------------------------------
+
To run unit, integration, and system tests from the Breeze and your
virtualenv, you can use the `pytest `_ framework.
@@ -158,8 +197,8 @@ for debugging purposes, enter:
pytest --log-cli-level=DEBUG tests/core/test_core.py::TestCore
-Running Tests for a Specified Target Using Breeze from the Host
----------------------------------------------------------------
+Running Tests using Breeze from the Host
+----------------------------------------
If you wish to only run tests and not to drop into the shell, apply the
``tests`` command. You can add extra targets and pytest flags after the ``--`` command. Note that
@@ -168,20 +207,37 @@ to breeze.
.. code-block:: bash
- breeze tests tests/providers/http/hooks/test_http.py tests/core/test_core.py --db-reset --log-cli-level=DEBUG
+ breeze testing tests tests/providers/http/hooks/test_http.py tests/core/test_core.py --db-reset --log-cli-level=DEBUG
You can run the whole test suite without adding the test target:
.. code-block:: bash
- breeze tests --db-reset
+ breeze testing tests --db-reset
You can also specify individual tests or a group of tests:
.. code-block:: bash
- breeze tests --db-reset tests/core/test_core.py::TestCore
+ breeze testing tests --db-reset tests/core/test_core.py::TestCore
+
+You can also limit the tests to execute to a specific group of tests:
+
+.. code-block:: bash
+
+ breeze testing tests --test-type Core
+
+In case of Providers tests, you can run tests for all providers
+
+.. code-block:: bash
+
+ breeze testing tests --test-type Providers
+You can also limit the set of providers you would like to run tests of
+
+.. code-block:: bash
+
+ breeze testing tests --test-type "Providers[airbyte,http]"
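
The ``Name[item,item]`` form of the ``--test-type`` value shown above can be read as a type name plus an optional list of selected items. The sketch below only illustrates that syntax; it is an assumption about the format, not Breeze's actual parser, and ``parse_test_type`` is a hypothetical helper.

```python
import re

# Matches "Name" or "Name[item1,item2]"; the bracket syntax follows the example above.
_TEST_TYPE_RE = re.compile(r"^(?P<name>[A-Za-z]+)(?:\[(?P<items>[^\[\]]*)\])?$")


def parse_test_type(value: str) -> tuple[str, list[str]]:
    """Split a test type value into its name and the list of selected items."""
    match = _TEST_TYPE_RE.match(value.strip())
    if match is None:
        raise ValueError(f"unrecognised test type: {value!r}")
    items = match.group("items")
    # No brackets (or empty brackets) means no restriction on the items.
    selected = [item.strip() for item in items.split(",") if item.strip()] if items else []
    return match.group("name"), selected
```

With this reading, ``Providers[airbyte,http]`` yields the name ``Providers`` and the items ``airbyte`` and ``http``, while plain ``Core`` yields an empty item list.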
Running Tests of a specified type from the Host
-----------------------------------------------
@@ -201,15 +257,15 @@ kinds of test types:
.. code-block:: bash
- ./breeze-legacy --test-type Core --db-reset tests
+ breeze testing tests --test-type Core --db-reset tests
Runs all provider tests:
.. code-block:: bash
- ./breeze-legacy --test-type Providers --db-reset tests
+ breeze testing tests --test-type Providers --db-reset tests
-* Special kinds of tests - Integration, Quarantined, Postgres, MySQL, which are marked with pytest
+* Special kinds of tests - Quarantined, Postgres, MySQL - which are marked with pytest
marks and for those you need to select the type using test-type switch. If you want to run such tests
using breeze, you need to pass appropriate ``--test-type`` otherwise the test will be skipped.
Similarly to the per-directory tests if you do not specify the test or tests to run,
@@ -219,74 +275,168 @@ kinds of test types:
.. code-block:: bash
- ./breeze-legacy --test-type Quarantined tests tests/cli/commands/test_task_command.py --db-reset
+ breeze testing tests --test-type Quarantined tests tests/cli/commands/test_task_command.py --db-reset
Run all Quarantined tests:
.. code-block:: bash
- ./breeze-legacy --test-type Quarantined tests --db-reset
+ breeze testing tests --test-type Quarantined tests --db-reset
-Helm Unit Tests
-===============
-On the Airflow Project, we have decided to stick with pythonic testing for our Helm chart. This makes our chart
-easier to test, easier to modify, and able to run with the same testing infrastructure. To add Helm unit tests
-add them in ``tests/charts``.
+Running full Airflow unit test suite in parallel
+------------------------------------------------
+
+If you run ``breeze testing tests --run-in-parallel`` tests run in parallel
+on your development machine - maxing out the number of parallel runs at the number of cores you
+have available in your Docker engine.
+
+In case you do not have enough memory available to your Docker (8 GB), the ``Integration``, ``Provider``
+and ``Core`` test types are executed sequentially with cleaning the docker setup in-between. This
+allows to print
+
+This allows for massive speedup in full test execution. On an 8 CPU machine with 16 cores and 64 GB memory
+and fast SSD disk, the whole suite of tests completes in about 5 minutes (!). Same suite of tests takes
+more than 30 minutes on the same machine when tests are run sequentially.
+
+.. note::
+
+ On MacOS you might have fewer CPUs and less memory available to run the tests than you have in the host,
+ simply because your Docker engine runs in a Linux Virtual Machine under-the-hood. If you want to make
+ use of the parallelism and memory usage for the CI tests you might want to increase the resources available
+ to your docker engine. See the `Resources `_ chapter
+ in the ``Docker for Mac`` documentation on how to do it.
+
+You can also limit the parallelism by specifying the maximum number of parallel jobs via
+MAX_PARALLEL_TEST_JOBS variable. If you set it to "1", all the test types will be run sequentially.
+
+.. code-block:: bash
+
+ MAX_PARALLEL_TEST_JOBS="1" ./scripts/ci/testing/ci_run_airflow_testing.sh
+
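The interaction between the number of available cores and ``MAX_PARALLEL_TEST_JOBS`` can be sketched as below. This assumes the variable acts as an upper bound on top of the core count, which is one reading of the text above rather than the confirmed behaviour of the CI scripts; ``parallel_jobs`` is a hypothetical helper.

```python
import os
from typing import Mapping, Optional


def parallel_jobs(cpu_count: int, environ: Optional[Mapping[str, str]] = None) -> int:
    """Return how many test types to run at once (always at least 1)."""
    environ = os.environ if environ is None else environ
    limit = environ.get("MAX_PARALLEL_TEST_JOBS")
    if limit is None:
        # No explicit limit: use as many jobs as there are cores.
        return max(1, cpu_count)
    # Assumption: an explicit limit can only lower the job count, never raise it.
    return max(1, min(cpu_count, int(limit)))
```

Under this assumption, setting the variable to ``"1"`` gives fully sequential execution regardless of how many cores Docker has available.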
+.. note::
+
+ In case you would like to cleanup after execution of such tests you might have to cleanup
+ some of the docker containers running in case you use ctrl-c to stop execution. You can easily do it by
+ running this command (it will kill all docker containers running so do not use it if you want to keep some
+ docker containers running):
+
+ .. code-block:: bash
+
+ docker kill $(docker ps -q)
+
+Running Backend-Specific Tests
+------------------------------
+
+Tests that are using a specific backend are marked with a custom pytest marker ``pytest.mark.backend``.
+The marker has a single parameter - the name of a backend. It corresponds to the ``--backend`` switch of
+the Breeze environment (one of ``mysql``, ``sqlite``, or ``postgres``). Backend-specific tests only run when
+the Breeze environment is running with the right backend. If you specify more than one backend
+in the marker, the test runs for all specified backends.
+
+Example of the ``postgres`` only test:
.. code-block:: python
- class TestBaseChartTest(unittest.TestCase):
+ @pytest.mark.backend("postgres")
+ def test_copy_expert(self):
...
-To render the chart create a YAML string with the nested dictionary of options you wish to test. You can then
-use our ``render_chart`` function to render the object of interest into a testable Python dictionary. Once the chart
-has been rendered, you can use the ``render_k8s_object`` function to create a k8s model object. It simultaneously
-ensures that the object created properly conforms to the expected resource spec and allows you to use object values
-instead of nested dictionaries.
-Example test here:
+Example of the ``postgres,mysql`` test (they are skipped with the ``sqlite`` backend):
.. code-block:: python
- from tests.charts.helm_template_generator import render_chart, render_k8s_object
+ @pytest.mark.backend("postgres", "mysql")
+ def test_celery_executor(self):
+ ...
- git_sync_basic = """
- dags:
- gitSync:
- enabled: true
- """
+You can use the custom ``--backend`` switch in pytest to only run tests specific for that backend.
+Here is an example of running only postgres-specific backend tests:
+
+.. code-block:: bash
- class TestGitSyncScheduler(unittest.TestCase):
- def test_basic(self):
- helm_settings = yaml.safe_load(git_sync_basic)
- res = render_chart(
- "GIT-SYNC",
- helm_settings,
- show_only=["templates/scheduler/scheduler-deployment.yaml"],
- )
- dep: k8s.V1Deployment = render_k8s_object(res[0], k8s.V1Deployment)
- assert "dags" == dep.spec.template.spec.volumes[1].name
+ pytest --backend postgres
+
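The rule described above (a backend-marked test runs only when the active backend is one of the marker's arguments) can be sketched as a pure function. This is a minimal illustration, not Airflow's ``conftest.py``; ``should_run`` is a hypothetical name, and treating unmarked tests as runnable on every backend is an assumption.

```python
def should_run(marker_backends: tuple[str, ...], active_backend: str) -> bool:
    """Decide whether a test runs for the active backend.

    marker_backends holds the arguments of ``pytest.mark.backend`` on the test,
    or an empty tuple when the test carries no such marker.
    """
    if not marker_backends:
        # Assumption: tests without the marker are not backend-specific.
        return True
    return active_backend in marker_backends
```

In a real test setup such a check would typically be called from a pytest hook such as ``pytest_runtest_setup``, skipping the test when it returns ``False``.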
+Running Long-running tests
+--------------------------
+
+Some of the tests run for a long time. Such tests are marked with the ``@pytest.mark.long_running`` annotation.
+Those tests are skipped by default. You can enable them with the ``--include-long-running`` flag. You
+can also use the ``-m long_running`` flag to run only those tests.
-To run tests using breeze run the following command
+Running Quarantined tests
+-------------------------
+
+Some of our tests are quarantined. This means that these tests will be run in isolation and that they will be
+re-run several times. Also when quarantined tests fail, the whole test suite will not fail. The quarantined
+tests are usually flaky tests that need some attention and fix.
+
+Those tests are marked with the ``@pytest.mark.quarantined`` annotation.
+Those tests are skipped by default. You can enable them with the ``--include-quarantined`` flag. You
+can also use the ``-m quarantined`` flag to run only those tests.
+
+Running Tests with provider packages
+------------------------------------
+
+Airflow 2.0 introduced the concept of splitting the monolithic Airflow package into separate
+providers packages. The main "apache-airflow" package contains the bare Airflow implementation,
+and additionally we have 70+ providers that we can install additionally to get integrations with
+external services. Those providers live in the same monorepo as Airflow, but we build separate
+packages for them and the main "apache-airflow" package does not contain the providers.
+
+Most of the development in Breeze happens by iterating on sources and when you run
+your tests during development, you usually do not want to build packages and install them separately.
+Therefore by default, when you enter Breeze, Airflow and all providers are available directly from
+sources rather than installed from packages. Installing from packages is needed, for example, to test the
+"provider discovery" mechanism that reads provider information from the package meta-data.
+
+When Airflow is run from sources, the metadata is read from provider.yaml
+files, but when Airflow is installed from packages, it is read via the package entrypoint
+``apache_airflow_provider``.
+
+By default, all packages are prepared in wheel format. To install Airflow from packages you
+need to run the following steps:
+
+1. Prepare provider packages
.. code-block:: bash
- ./breeze-legacy --test-type Helm tests
+ breeze release-management prepare-provider-packages [PACKAGE ...]
+
+If you run this command without packages, you will prepare all packages. However, you can specify
+providers that you would like to build if you just want to build a few provider packages.
+The packages are prepared in ``dist`` folder. Note that this command cleans up the ``dist`` folder
+before running, so you should run it before generating ``apache-airflow`` package.
+
+2. Prepare airflow packages
+
+.. code-block:: bash
+
+ breeze release-management prepare-airflow-package
+
+This prepares airflow .whl package in the dist folder.
+
+3. Enter breeze installing both airflow and providers from the dist packages
+
+.. code-block:: bash
+
+ breeze --use-airflow-version wheel --use-packages-from-dist --skip-mounting-local-sources
+
Airflow Integration Tests
=========================
Some of the tests in Airflow are integration tests. These tests require ``airflow`` Docker
-image and extra images with integrations (such as ``redis``, ``mongodb``, etc.).
-
+image and extra images with integrations (such as ``celery``, ``mongodb``, etc.).
+The integration tests are all stored in the ``tests/integration`` folder.
Enabling Integrations
---------------------
Airflow integration tests cannot be run in the local virtualenv. They can only run in the Breeze
-environment with enabled integrations and in the CI. See `<.github/workflows/ci.yml>`_ for details about Airflow CI.
+environment with enabled integrations and in the CI. See `CI `_ for details about Airflow CI.
When you are in the Breeze environment, by default, all integrations are disabled. This enables only true unit tests
to be executed in Breeze. You can enable the integration by passing the ``--integration ``
@@ -295,8 +445,7 @@ or using the ``--integration all`` switch that enables all integrations.
NOTE: Every integration requires a separate container with the corresponding integration image.
These containers take precious resources on your PC, mainly the memory. The started integrations are not stopped
-until you stop the Breeze environment with the ``stop`` command and restart it
-via ``restart`` command.
+until you stop the Breeze environment with the ``stop`` command and start it again with the ``start`` command.
The following integrations are available:
@@ -312,13 +461,9 @@ The following integrations are available:
- Integration that provides Kerberos authentication
* - mongo
- Integration required for MongoDB hooks
- * - openldap
- - Integration required for OpenLDAP hooks
* - pinot
- Integration required for Apache Pinot hooks
- * - rabbitmq
- - Integration required for Celery executor tests
- * - redis
+ * - celery
- Integration required for Celery executor tests
* - trino
- Integration required for Trino hooks
@@ -341,10 +486,6 @@ To start all integrations, enter:
breeze --integration all
-In the CI environment, integrations can be enabled by specifying the ``ENABLED_INTEGRATIONS`` variable
-storing a space-separated list of integrations to start. Thanks to that, we can run integration and
-integration-less tests separately in different jobs, which is desired from the memory usage point of view.
-
Note that Kerberos is a special kind of integration. Some tests run differently when
Kerberos integration is enabled (they retrieve and use a Kerberos authentication token) and differently when the
Kerberos integration is disabled (they neither retrieve nor use the token). Therefore, one of the test jobs
@@ -356,11 +497,11 @@ Running Integration Tests
All tests using an integration are marked with a custom pytest marker ``pytest.mark.integration``.
The marker has a single parameter - the name of integration.
-Example of the ``redis`` integration test:
+Example of the ``celery`` integration test:
.. code-block:: python
- @pytest.mark.integration("redis")
+ @pytest.mark.integration("celery")
def test_real_ping(self):
hook = RedisHook(redis_conn_id="redis_default")
redis = hook.get_conn()
@@ -384,267 +525,202 @@ To run only ``mongo`` integration tests:
.. code-block:: bash
- pytest --integration mongo
+ pytest --integration mongo tests/integration
-To run integration tests for ``mongo`` and ``rabbitmq``:
+To run integration tests for ``mongo`` and ``celery``:
.. code-block:: bash
- pytest --integration mongo --integration rabbitmq
-
-Note that collecting all tests takes some time. So, if you know where your tests are located, you can
-speed up the test collection significantly by providing the folder where the tests are located.
-
-Here is an example of the collection limited to the ``providers/apache`` directory:
-
-.. code-block:: bash
-
- pytest --integration cassandra tests/providers/apache/
-
-Running Backend-Specific Tests
-------------------------------
-
-Tests that are using a specific backend are marked with a custom pytest marker ``pytest.mark.backend``.
-The marker has a single parameter - the name of a backend. It corresponds to the ``--backend`` switch of
-the Breeze environment (one of ``mysql``, ``sqlite``, or ``postgres``). Backend-specific tests only run when
-the Breeze environment is running with the right backend. If you specify more than one backend
-in the marker, the test runs for all specified backends.
-
-Example of the ``postgres`` only test:
-
-.. code-block:: python
-
- @pytest.mark.backend("postgres")
- def test_copy_expert(self):
- ...
-
+ pytest --integration mongo --integration celery tests/integration
-Example of the ``postgres,mysql`` test (they are skipped with the ``sqlite`` backend):
-
-.. code-block:: python
-
- @pytest.mark.backend("postgres", "mysql")
- def test_celery_executor(self):
- ...
-
-You can use the custom ``--backend`` switch in pytest to only run tests specific for that backend.
-Here is an example of running only postgres-specific backend tests:
+Here is an example of the collection limited to the ``providers/apache`` sub-directory:
.. code-block:: bash
- pytest --backend postgres
-
-Running Long-running tests
---------------------------
-
-Some of the tests rung for a long time. Such tests are marked with ``@pytest.mark.long_running`` annotation.
-Those tests are skipped by default. You can enable them with ``--include-long-running`` flag. You
-can also decide to only run tests with ``-m long-running`` flags to run only those tests.
+ pytest --integration cassandra tests/integration/providers/apache
-Quarantined tests
------------------
+Running Integration Tests from the Host
+---------------------------------------
-Some of our tests are quarantined. This means that this test will be run in isolation and that it will be
-re-run several times. Also when quarantined tests fail, the whole test suite will not fail. The quarantined
-tests are usually flaky tests that need some attention and fix.
+You can also run integration tests using Breeze from the host.
-Those tests are marked with ``@pytest.mark.quarantined`` annotation.
-Those tests are skipped by default. You can enable them with ``--include-quarantined`` flag. You
-can also decide to only run tests with ``-m quarantined`` flag to run only those tests.
+Runs all integration tests:
+ .. code-block:: bash
-Airflow test types
-==================
+ breeze testing integration-tests --db-reset --integration all
-Airflow tests in the CI environment are split into several test types:
+Runs all MongoDB tests:
-* Always - those are tests that should be always executed (always folder)
-* Core - for the core Airflow functionality (core folder)
-* API - Tests for the Airflow API (api and api_connexion folders)
-* CLI - Tests for the Airflow CLI (cli folder)
-* WWW - Tests for the Airflow webserver (www folder)
-* Providers - Tests for all Providers of Airflow (providers folder)
-* Other - all other tests (all other folders that are not part of any of the above)
+ .. code-block:: bash
-This is done for three reasons:
+ breeze testing integration-tests --db-reset --integration mongo
-1. in order to selectively run only subset of the test types for some PRs
-2. in order to allow parallel execution of the tests on Self-Hosted runners
+Helm Unit Tests
+===============
-For case 1. see `Pull Request Workflow `_ for details.
+In the Airflow project, we have decided to stick with Pythonic testing for our Helm chart. This makes our chart
+easier to test, easier to modify, and able to run with the same testing infrastructure. To add Helm unit tests,
+add them in ``tests/charts``.
-For case 2. We can utilise memory and CPUs available on both CI and local development machines to run
-test in parallel. This way we can decrease the time of running all tests in self-hosted runners from
-60 minutes to ~15 minutes.
+.. code-block:: python
-.. note::
+ class TestBaseChartTest:
+ ...
- We need to split tests manually into separate suites rather than utilise
- ``pytest-xdist`` or ``pytest-parallel`` which could be a simpler and much more "native" parallelization
- mechanism. Unfortunately, we cannot utilise those tools because our tests are not truly ``unit`` tests that
- can run in parallel. A lot of our tests rely on shared databases - and they update/reset/cleanup the
- databases while they are executing. They are also exercising features of the Database such as locking which
- further increases cross-dependency between tests. Until we make all our tests truly unit tests (and not
- touching the database or until we isolate all such tests to a separate test type, we cannot really rely on
- frameworks that run tests in parallel. In our solution each of the test types is run in parallel with its
- own database (!) so when we have 8 test types running in parallel, there are in fact 8 databases run
- behind the scenes to support them and each of the test types executes its own tests sequentially.
+To render the chart, create a YAML string with the nested dictionary of options you wish to test. You can then
+use our ``render_chart`` function to render the object of interest into a testable Python dictionary. Once the chart
+has been rendered, you can use the ``render_k8s_object`` function to create a k8s model object. It simultaneously
+ensures that the object created properly conforms to the expected resource spec and allows you to use object values
+instead of nested dictionaries.
+Example test:
-Running full Airflow test suite in parallel
-===========================================
+.. code-block:: python
-If you run ``./scripts/ci/testing/ci_run_airflow_testing.sh`` tests run in parallel
-on your development machine - maxing out the number of parallel runs at the number of cores you
-have available in your Docker engine.
+ import yaml
+ from kubernetes.client import models as k8s
+
+ from tests.charts.helm_template_generator import render_chart, render_k8s_object
-In case you do not have enough memory available to your Docker (~32 GB), the ``Integration`` test type
-is always run sequentially - after all tests are completed (docker cleanup is performed in-between).
+ git_sync_basic = """
+ dags:
+   gitSync:
+     enabled: true
+ """
-This allows for massive speedup in full test execution. On 8 CPU machine with 16 cores and 64 GB memory
-and fast SSD disk, the whole suite of tests completes in about 5 minutes (!). Same suite of tests takes
-more than 30 minutes on the same machine when tests are run sequentially.
-.. note::
+ class TestGitSyncScheduler:
+     def test_basic(self):
+         helm_settings = yaml.safe_load(git_sync_basic)
+         res = render_chart(
+             "GIT-SYNC",
+             helm_settings,
+             show_only=["templates/scheduler/scheduler-deployment.yaml"],
+         )
+         dep: k8s.V1Deployment = render_k8s_object(res[0], k8s.V1Deployment)
+         assert "dags" == dep.spec.template.spec.volumes[1].name
- On MacOS you might have less CPUs and less memory available to run the tests than you have in the host,
- simply because your Docker engine runs in a Linux Virtual Machine under-the-hood. If you want to make
- use of the parallelism and memory usage for the CI tests you might want to increase the resources available
- to your docker engine. See the `Resources `_ chapter
- in the ``Docker for Mac`` documentation on how to do it.
-You can also limit the parallelism by specifying the maximum number of parallel jobs via
-MAX_PARALLEL_TEST_JOBS variable. If you set it to "1", all the test types will be run sequentially.
+To execute all Helm tests with the ``breeze`` command and run them in parallel with pytest, use the
+following command (note that it takes quite a long time even on a multi-processor machine):
.. code-block:: bash
- MAX_PARALLEL_TEST_JOBS="1" ./scripts/ci/testing/ci_run_airflow_testing.sh
-
-.. note::
-
- In case you would like to cleanup after execution of such tests you might have to cleanup
- some of the docker containers running in case you use ctrl-c to stop execution. You can easily do it by
- running this command (it will kill all docker containers running so do not use it if you want to keep some
- docker containers running):
-
- .. code-block:: bash
-
- docker kill $(docker ps -q)
-
+ breeze testing helm-tests
-Running Tests with provider packages
-====================================
-
-Airflow 2.0 introduced the concept of splitting the monolithic Airflow package into separate
-providers packages. The main "apache-airflow" package contains the bare Airflow implementation,
-and additionally we have 70+ providers that we can install additionally to get integrations with
-external services. Those providers live in the same monorepo as Airflow, but we build separate
-packages for them and the main "apache-airflow" package does not contain the providers.
+You can also run Helm tests individually via the usual ``breeze`` command. Just enter breeze and run the
+tests with pytest as you would do with regular unit tests. You can add the ``-n auto`` flag to run Helm
+tests in parallel: unlike most of our regular unit tests, which require a database, the Helm tests are
+perfectly safe to run in parallel, and if you have multiple processors, you can gain significant
+speedups from parallel runs:
-Most of the development in Breeze happens by iterating on sources and when you run
-your tests during development, you usually do not want to build packages and install them separately.
-Therefore by default, when you enter Breeze airflow and all providers are available directly from
-sources rather than installed from packages. This is for example to test the "provider discovery"
-mechanism available that reads provider information from the package meta-data.
-
-When Airflow is run from sources, the metadata is read from provider.yaml
-files, but when Airflow is installed from packages, it is read via the package entrypoint
-``apache_airflow_provider``.
+.. code-block:: bash
-By default, all packages are prepared in wheel format. To install Airflow from packages you
-need to run the following steps:
+ breeze
-1. Prepare provider packages
+This enters the Breeze container.
.. code-block:: bash
- breeze prepare-provider-packages [PACKAGE ...]
-
-If you run this command without packages, you will prepare all packages. However, You can specify
-providers that you would like to build if you just want to build few provider packages.
-The packages are prepared in ``dist`` folder. Note that this command cleans up the ``dist`` folder
-before running, so you should run it before generating ``apache-airflow`` package.
+ pytest tests/charts -n auto
-2. Prepare airflow packages
+This runs all chart tests using all processors you have available.
.. code-block:: bash
- breeze prepare-airflow-package
-
-This prepares airflow .whl package in the dist folder.
-
-3. Enter breeze installing both airflow and providers from the packages
+ pytest tests/charts/test_airflow_common.py -n auto
-This installs airflow and enters
+This will run all tests from the ``test_airflow_common.py`` file using all processors you have available.
.. code-block:: bash
- ./breeze-legacy --use-airflow-version wheel --use-packages-from-dist --skip-mounting-local-sources
+ pytest tests/charts/test_airflow_common.py
+This will run all tests from the ``test_airflow_common.py`` file sequentially.
-Running Tests with Kubernetes
-=============================
+Kubernetes tests
+================
Airflow has tests that are run against real Kubernetes cluster. We are using
`Kind `_ to create and run the cluster. We integrated the tools to start/stop/
deploy and run the cluster tests in our repository and into Breeze development environment.
-Configuration for the cluster is kept in ``./build/.kube/config`` file in your Airflow source repository, and
-our scripts set the ``KUBECONFIG`` variable to it. If you want to interact with the Kind cluster created
-you can do it from outside of the scripts by exporting this variable and point it to this file.
+KinD has a really nice ``kind`` tool that you can use to interact with the cluster. Run ``kind --help`` to
+learn more.
-Starting Kubernetes Cluster
----------------------------
+K8S test environment
+------------------------
-For your testing, you manage Kind cluster with ``kind-cluster`` breeze command:
+Before running ``breeze k8s`` cluster commands you need to set up the environment. This is done
+by the ``breeze k8s setup-env`` command. In this command, Breeze downloads the tools that
+are needed to run k8s tests (Helm, Kind and Kubectl, in the right versions) and sets up a
+Python virtualenv that is needed to run the tests. All those tools and the virtualenv are set up in the
+``.build/.k8s-env`` folder. You can activate this environment yourself as usual by sourcing the
+``bin/activate`` script, but since we support multiple clusters in the same installation,
+it is best to use ``breeze k8s shell`` with the right parameters specifying which cluster
+to use.
-.. code-block:: bash
+Multiple cluster support
+------------------------
- ./breeze-legacy kind-cluster [ start | stop | recreate | status | deploy | test | shell | k9s ]
+The main feature of the ``breeze k8s`` command is that it allows you to manage multiple KinD clusters, one
+per combination of Python and Kubernetes version. This is used during CI, where we can run the same
+tests against those different clusters, even in parallel.
-The command allows you to start/stop/recreate/status Kind Kubernetes cluster, deploy Airflow via Helm
-chart as well as interact with the cluster (via test and shell commands).
+The cluster name follows the pattern ``airflow-python-X.Y-vA.B.C`` where X.Y is a major/minor Python version
+and A.B.C is Kubernetes version. Example cluster name: ``airflow-python-3.7-v1.24.0``
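As a rough illustration of the naming pattern above, the following sketch builds and parses a cluster name. The helpers ``cluster_name`` and ``parse_cluster_name`` are hypothetical names chosen for this example; they are not part of Breeze.

```python
import re
from typing import Tuple

# Hypothetical helpers illustrating the documented naming pattern;
# they are not part of Breeze itself.
CLUSTER_NAME_RE = re.compile(
    r"^airflow-python-(?P<python>\d+\.\d+)-(?P<kubernetes>v\d+\.\d+\.\d+)$"
)


def cluster_name(python: str, kubernetes_version: str) -> str:
    """Build a name following the airflow-python-X.Y-vA.B.C pattern."""
    return f"airflow-python-{python}-{kubernetes_version}"


def parse_cluster_name(name: str) -> Tuple[str, str]:
    """Split a cluster name back into (python, kubernetes_version)."""
    match = CLUSTER_NAME_RE.match(name)
    if match is None:
        raise ValueError(f"Not a cluster name of the expected shape: {name!r}")
    return match.group("python"), match.group("kubernetes")


print(cluster_name("3.7", "v1.24.0"))
print(parse_cluster_name("airflow-python-3.7-v1.24.0"))
```

Parsing the name this way can be handy in scripts that need to recover the Python and Kubernetes versions from a list of existing clusters.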
-Setting up the Kind Kubernetes cluster takes some time, so once you started it, the cluster continues running
-until it is stopped with the ``kind-cluster stop`` command or until ``kind-cluster recreate``
-command is used (it will stop and recreate the cluster image).
+Most of the commands can be executed in parallel for multiple images/clusters by adding ``--run-in-parallel``
+to create clusters or deploy airflow. Similarly checking for status, dumping logs and deleting clusters
+can be run with ``--all`` flag and they will be executed sequentially for all locally created clusters.
-The cluster name follows the pattern ``airflow-python-X.Y-vA.B.C`` where X.Y is a Python version
-and A.B.C is a Kubernetes version. This way you can have multiple clusters set up and running at the same
-time for different Python versions and different Kubernetes versions.
+Per-cluster configuration files
+-------------------------------
+Once you start the cluster, the configuration for it is stored in a dynamically created folder - a separate
+folder for each python/kubernetes_version combination. The folder is ``./.build/.k8s-clusters/``
-Deploying Airflow to Kubernetes Cluster
----------------------------------------
+There are two files there:
-Deploying Airflow to the Kubernetes cluster created is also done via ``kind-cluster deploy`` breeze command:
+* The kubectl configuration, stored in the ``.kubeconfig`` file - our scripts set the ``KUBECONFIG`` variable to it
+* The KinD cluster configuration, stored in the ``.kindconfig.yaml`` file - our scripts set the ``KINDCONFIG`` variable to it
-.. code-block:: bash
+The ``KUBECONFIG`` file is automatically used when you enter any of the ``breeze k8s`` commands that use
+``kubectl`` or when you run ``kubectl`` in the k8s shell. The ``KINDCONFIG`` file is used when the cluster is
+started, but you and the ``k8s`` command can inspect it to find out, for example, which port is forwarded to the
+webserver running in the cluster.
+
+The files are deleted by ``breeze k8s delete-cluster`` command.
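A minimal sketch of how a script might locate these files and read the forwarded port is shown below. The folder layout and file names are taken from the sample ``breeze k8s status`` output later in this document; the helpers ``cluster_config_files`` and ``forwarded_host_port`` are hypothetical, and the port is found with a simple regular expression rather than a full YAML parser to keep the example dependency-free.

```python
import re
from pathlib import Path
from typing import Dict, Optional


def cluster_config_files(airflow_root: Path, cluster: str) -> Dict[str, Path]:
    """Return the assumed locations of the two per-cluster files."""
    folder = airflow_root / ".build" / ".k8s-clusters" / cluster
    return {
        "KUBECONFIG": folder / ".kubeconfig",
        "KINDCONFIG": folder / ".kindconfig.yaml",
    }


def forwarded_host_port(kindconfig_text: str) -> Optional[int]:
    """Find the first ``hostPort`` entry in a KinD config, if any."""
    match = re.search(r"^\s*hostPort:\s*(\d+)\s*$", kindconfig_text, re.MULTILINE)
    return int(match.group(1)) if match else None


# Fragment shaped like the KinD config printed by ``breeze k8s create-cluster``.
sample = """\
nodes:
  - role: control-plane
  - role: worker
    extraPortMappings:
      - containerPort: 30007
        hostPort: 18150
"""

files = cluster_config_files(Path("."), "airflow-python-3.7-v1.24.2")
print(files["KUBECONFIG"])
print(forwarded_host_port(sample))
```

For anything beyond a quick inspection, a real YAML parser would be the more robust choice.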
- ./breeze-legacy kind-cluster deploy
+Managing Kubernetes Cluster
+---------------------------
-The deploy command performs those steps:
+For your testing, you manage the Kind cluster with the ``k8s`` breeze command group. The group provides
+the following commands:
-1. It rebuilds the latest ``apache/airflow:main-pythonX.Y`` production images using the
- latest sources using local caching. It also adds example DAGs to the image, so that they do not
- have to be mounted inside.
-2. Loads the image to the Kind Cluster using the ``kind load`` command.
-3. Starts airflow in the cluster using the official helm chart (in ``airflow`` namespace)
-4. Forwards Local 8080 port to the webserver running in the cluster
-5. Applies the volumes.yaml to get the volumes deployed to ``default`` namespace - this is where
- KubernetesExecutor starts its pods.
+.. image:: ./images/breeze/output_k8s.svg
+ :width: 100%
+ :alt: Breeze k8s
-You can also specify a different executor by providing the ``--executor`` optional argument:
+The command group allows you to set up the environment, start/stop/recreate/check the status of the Kind
+Kubernetes cluster, and configure the cluster (via the ``create-cluster`` and ``configure-cluster`` commands).
+Those commands can be run with the ``--run-in-parallel`` flag to execute them in parallel for all/selected clusters.
-.. code-block:: bash
+In order to deploy Airflow, the PROD image of Airflow needs to be extended, and example DAGs and pod
+template files should be added to the image. This is done via the ``build-k8s-image`` and ``upload-k8s-image`` commands.
+This can also be done for all/selected images/clusters in parallel via the ``--run-in-parallel`` flag.
- ./breeze-legacy kind-cluster deploy --executor CeleryExecutor
+Then Airflow (by using Helm Chart) can be deployed to the cluster via the ``deploy-airflow`` command.
+This can also be done for all/selected images/clusters in parallel via the ``--run-in-parallel`` flag. You can
+pass extra options when deploying Airflow to configure your deployment.
-Note that when you specify the ``--executor`` option, it becomes the default. Therefore, every other operations
-on ``./breeze-legacy kind-cluster`` will default to using this executor. To change that, use the ``--executor`` option on the
-subsequent commands too.
+You can check the status, dump logs and finally delete the cluster via the ``status``, ``logs`` and ``delete-cluster``
+commands. This can also be done for all created clusters via the ``--all`` flag.
+
+You can interact with the cluster (via ``shell`` and ``k9s`` commands).
+
+You can run a set of k8s tests via the ``tests`` command. You can also run tests in parallel on all/selected
+clusters with the ``--run-in-parallel`` flag.
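The commands described above are typically run one after another. The sketch below only prints the command lines in that order and does not execute anything; the ``K8S_STEPS`` list and the ``k8s_session_commands`` helper are illustrative names for this example, not part of Breeze.

```python
from typing import List

# Sub-commands of ``breeze k8s`` in the order this document walks through them.
K8S_STEPS = [
    "setup-env",
    "create-cluster",
    "configure-cluster",
    "build-k8s-image",
    "upload-k8s-image",
    "deploy-airflow",
    "tests",
]


def k8s_session_commands() -> List[str]:
    """Return the command lines for one full test session (dry run only)."""
    return [" ".join(["breeze", "k8s", step]) for step in K8S_STEPS]


for command in k8s_session_commands():
    print(command)
```

Printing the commands first makes it easy to review the sequence before running the individual steps by hand.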
Running tests with Kubernetes Cluster
@@ -653,29 +729,20 @@ Running tests with Kubernetes Cluster
You can either run all tests or you can select which tests to run. You can also enter interactive virtualenv
to run the tests manually one by one.
-Running Kubernetes tests via shell:
-
-.. code-block:: bash
-
- export EXECUTOR="KubernetesExecutor" ## can be also CeleryExecutor or CeleryKubernetesExecutor
-
- ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh - runs all kubernetes tests
- ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh TEST [TEST ...] - runs selected kubernetes tests (from kubernetes_tests folder)
-
Running Kubernetes tests via breeze:
.. code-block:: bash
- ./breeze-legacy kind-cluster test
- ./breeze-legacy kind-cluster test -- TEST TEST [TEST ...]
+ breeze k8s tests
+ breeze k8s tests TEST TEST [TEST ...]
Optionally add ``--executor``:
.. code-block:: bash
- ./breeze-legacy kind-cluster test --executor CeleryExecutor
- ./breeze-legacy kind-cluster test -- TEST TEST [TEST ...] --executor CeleryExecutor
+ breeze k8s tests --executor CeleryExecutor
+ breeze k8s tests --executor CeleryExecutor TEST TEST [TEST ...]
Entering shell with Kubernetes Cluster
--------------------------------------
@@ -683,31 +750,18 @@ Entering shell with Kubernetes Cluster
This shell is prepared to run Kubernetes tests interactively. It has ``kubectl`` and ``kind`` cli tools
available in the path, it has also activated virtualenv environment that allows you to run tests via pytest.
-The binaries are available in ./.build/kubernetes-bin/``KUBERNETES_VERSION`` path.
-The virtualenv is available in ./.build/.kubernetes_venv/``KIND_CLUSTER_NAME``_host_python_``HOST_PYTHON_VERSION``
-
-Where ``KIND_CLUSTER_NAME`` is the name of the cluster and ``HOST_PYTHON_VERSION`` is the version of python
-in the host.
-
-You can enter the shell via those scripts
-
-.. code-block:: bash
-
- export EXECUTOR="KubernetesExecutor" ## can be also CeleryExecutor or CeleryKubernetesExecutor
-
- ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh [-i|--interactive] - Activates virtual environment ready to run tests and drops you in
- ./scripts/ci/kubernetes/ci_run_kubernetes_tests.sh [--help] - Prints this help message
-
+The virtualenv is available in ``./.build/.k8s-env/``.
+The binaries are available in the ``.build/.k8s-env/bin`` path.
.. code-block:: bash
- ./breeze-legacy kind-cluster shell
+ breeze k8s shell
Optionally add ``--executor``:
.. code-block:: bash
- ./breeze-legacy kind-cluster shell --executor CeleryExecutor
+ breeze k8s shell --executor CeleryExecutor
K9s CLI - debug Kubernetes in style!
@@ -734,7 +788,7 @@ You can enter the k9s tool via breeze (after you deployed Airflow):
.. code-block:: bash
- ./breeze-legacy kind-cluster k9s
+ breeze k8s k9s
You can exit k9s by pressing Ctrl-C.
@@ -743,73 +797,316 @@ Typical testing pattern for Kubernetes tests
The typical session for tests with Kubernetes looks like follows:
-1. Start the Kind cluster:
+
+1. Prepare the environment:
.. code-block:: bash
- ./breeze-legacy kind-cluster start
+ breeze k8s setup-env
- Starts Kind Kubernetes cluster
+The first time you run it, it should create the virtualenv and install the right versions
+of kind, kubectl and helm. All of them are installed in ``./.build/.k8s-env`` (binaries are available in the ``bin``
+sub-folder).
- Use CI image.
+.. code-block:: text
+
+ Initializing K8S virtualenv in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env
+ Reinstalling PIP version in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env
+ Installing necessary packages in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env
+ The ``kind`` tool is not downloaded yet. Downloading 0.14.0 version.
+ Downloading from: https://github.com/kubernetes-sigs/kind/releases/download/v0.14.0/kind-darwin-arm64
+ The ``kubectl`` tool is not downloaded yet. Downloading 1.24.3 version.
+ Downloading from: https://storage.googleapis.com/kubernetes-release/release/v1.24.3/bin/darwin/arm64/kubectl
+ The ``helm`` tool is not downloaded yet. Downloading 3.9.2 version.
+ Downloading from: https://get.helm.sh/helm-v3.9.2-darwin-arm64.tar.gz
+ Extracting the darwin-arm64/helm to /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin
+ Moving the helm to /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin/helm
- Branch name: main
- Docker image: apache/airflow:main-python3.7-ci
- Airflow source version: 2.0.0.dev0
- Python version: 3.7
- Backend: postgres 10
+This prepares the virtual environment for tests and downloads the right versions of the tools
+to ``./.build/.k8s-env``.
- No kind clusters found.
+2. Create the KinD cluster:
+
+.. code-block:: bash
- Creating cluster
+ breeze k8s create-cluster
+
+This should result in KinD creating the K8S cluster.
+
+.. code-block:: text
- Creating cluster "airflow-python-3.7-v1.17.0" ...
- ✓ Ensuring node image (kindest/node:v1.17.0) 🖼
+ Config created in /Users/jarek/IdeaProjects/airflow/.build/.k8s-clusters/airflow-python-3.7-v1.24.2/.kindconfig.yaml:
+
+ # Licensed to the Apache Software Foundation (ASF) under one
+ # or more contributor license agreements. See the NOTICE file
+ # distributed with this work for additional information
+ # regarding copyright ownership. The ASF licenses this file
+ # to you under the Apache License, Version 2.0 (the
+ # "License"); you may not use this file except in compliance
+ # with the License. You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, either express or implied. See the License for the
+ # specific language governing permissions and limitations
+ # under the License.
+ ---
+ kind: Cluster
+ apiVersion: kind.x-k8s.io/v1alpha4
+ networking:
+ ipFamily: ipv4
+ apiServerAddress: "127.0.0.1"
+ apiServerPort: 48366
+ nodes:
+ - role: control-plane
+ - role: worker
+ extraPortMappings:
+ - containerPort: 30007
+ hostPort: 18150
+ listenAddress: "127.0.0.1"
+ protocol: TCP
+
+
+
+ Creating cluster "airflow-python-3.7-v1.24.2" ...
+ ✓ Ensuring node image (kindest/node:v1.24.2) 🖼
✓ Preparing nodes 📦 📦
✓ Writing configuration 📜
✓ Starting control-plane 🕹️
✓ Installing CNI 🔌
- Could not read storage manifest, falling back on old k8s.io/host-path default ...
✓ Installing StorageClass 💾
✓ Joining worker nodes 🚜
- Set kubectl context to "kind-airflow-python-3.7-v1.17.0"
+ Set kubectl context to "kind-airflow-python-3.7-v1.24.2"
You can now use your cluster with:
- kubectl cluster-info --context kind-airflow-python-3.7-v1.17.0
+ kubectl cluster-info --context kind-airflow-python-3.7-v1.24.2
- Have a question, bug, or feature request? Let us know! https://kind.sigs.k8s.io/#community 🙂
+ Not sure what to do next? 😅 Check out https://kind.sigs.k8s.io/docs/user/quick-start/
- Created cluster airflow-python-3.7-v1.17.0
+ KinD Cluster API server URL: http://localhost:48366
+ Connecting to localhost:18150. Num try: 1
+ Error when connecting to localhost:18150 : ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
+ Airflow webserver is not available at port 18150. Run `breeze k8s deploy-airflow --python 3.7 --kubernetes-version v1.24.2` to (re)deploy airflow
-2. Check the status of the cluster
+ KinD cluster airflow-python-3.7-v1.24.2 created!
+
+ NEXT STEP: You might now configure your cluster by:
+
+ breeze k8s configure-cluster
+
+3. Configure cluster for Airflow - this will recreate namespace and upload test resources for Airflow.
.. code-block:: bash
- ./breeze-legacy kind-cluster status
+ breeze k8s configure-cluster
+
+.. code-block:: text
- Checks status of Kind Kubernetes cluster
+ Configuring airflow-python-3.7-v1.24.2 to be ready for Airflow deployment
+ Deleting K8S namespaces for kind-airflow-python-3.7-v1.24.2
+ Error from server (NotFound): namespaces "airflow" not found
+ Error from server (NotFound): namespaces "test-namespace" not found
+ Creating namespaces
+ namespace/airflow created
+ namespace/test-namespace created
+ Created K8S namespaces for cluster kind-airflow-python-3.7-v1.24.2
- Use CI image.
+ Deploying test resources for cluster kind-airflow-python-3.7-v1.24.2
+ persistentvolume/test-volume created
+ persistentvolumeclaim/test-volume created
+ service/airflow-webserver-node-port created
+ Deployed test resources for cluster kind-airflow-python-3.7-v1.24.2
- Branch name: main
- Docker image: apache/airflow:main-python3.7-ci
- Airflow source version: 2.0.0.dev0
- Python version: 3.7
- Backend: postgres 10
+ NEXT STEP: You might now build your k8s image by:
- airflow-python-3.7-v1.17.0-control-plane
- airflow-python-3.7-v1.17.0-worker
+ breeze k8s build-k8s-image
-3. Deploy Airflow to the cluster
+4. Check the status of the cluster
.. code-block:: bash
- ./breeze-legacy kind-cluster deploy
+ breeze k8s status
+
+This should show the status of the current KinD cluster.
+
+.. code-block:: text
+
+ ========================================================================================================================
+ Cluster: airflow-python-3.7-v1.24.2
+
+ * KUBECONFIG=/Users/jarek/IdeaProjects/airflow/.build/.k8s-clusters/airflow-python-3.7-v1.24.2/.kubeconfig
+ * KINDCONFIG=/Users/jarek/IdeaProjects/airflow/.build/.k8s-clusters/airflow-python-3.7-v1.24.2/.kindconfig.yaml
+
+ Cluster info: airflow-python-3.7-v1.24.2
+
+ Kubernetes control plane is running at https://127.0.0.1:48366
+ CoreDNS is running at https://127.0.0.1:48366/api/v1/namespaces/kube-system/services/kube-dns:dns/proxy
+
+ To further debug and diagnose cluster problems, use 'kubectl cluster-info dump'.
+
+ Storage class for airflow-python-3.7-v1.24.2
+
+ NAME PROVISIONER RECLAIMPOLICY VOLUMEBINDINGMODE ALLOWVOLUMEEXPANSION AGE
+ standard (default) rancher.io/local-path Delete WaitForFirstConsumer false 83s
+
+ Running pods for airflow-python-3.7-v1.24.2
+
+ NAME READY STATUS RESTARTS AGE
+ coredns-6d4b75cb6d-rwp9d 1/1 Running 0 71s
+ coredns-6d4b75cb6d-vqnrc 1/1 Running 0 71s
+ etcd-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s
+ kindnet-ckc8l 1/1 Running 0 69s
+ kindnet-qqt8k 1/1 Running 0 71s
+ kube-apiserver-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s
+ kube-controller-manager-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s
+ kube-proxy-6g7hn 1/1 Running 0 69s
+ kube-proxy-dwfvp 1/1 Running 0 71s
+ kube-scheduler-airflow-python-3.7-v1.24.2-control-plane 1/1 Running 0 84s
+
+ KinD Cluster API server URL: http://localhost:48366
+ Connecting to localhost:18150. Num try: 1
+ Error when connecting to localhost:18150 : ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
+
+ Airflow webserver is not available at port 18150. Run `breeze k8s deploy-airflow --python 3.7 --kubernetes-version v1.24.2` to (re)deploy airflow
+
+
+ Cluster healthy: airflow-python-3.7-v1.24.2
+
+5. Build the image based on the PROD Airflow image. You need to build the PROD image first (the command will
+ guide you if you did not), either by running the build separately or by passing the ``--rebuild-base-image`` flag.
+
+.. code-block:: bash
+
+ breeze k8s build-k8s-image
+
+.. code-block:: text
+
+ Building the K8S image for Python 3.7 using airflow base image: ghcr.io/apache/airflow/main/prod/python3.7:latest
+
+ [+] Building 0.1s (8/8) FINISHED
+ => [internal] load build definition from Dockerfile 0.0s
+ => => transferring dockerfile: 301B 0.0s
+ => [internal] load .dockerignore 0.0s
+ => => transferring context: 35B 0.0s
+ => [internal] load metadata for ghcr.io/apache/airflow/main/prod/python3.7:latest 0.0s
+ => [1/3] FROM ghcr.io/apache/airflow/main/prod/python3.7:latest 0.0s
+ => [internal] load build context 0.0s
+ => => transferring context: 3.00kB 0.0s
+ => CACHED [2/3] COPY airflow/example_dags/ /opt/airflow/dags/ 0.0s
+ => CACHED [3/3] COPY airflow/kubernetes_executor_templates/ /opt/airflow/pod_templates/ 0.0s
+ => exporting to image 0.0s
+ => => exporting layers 0.0s
+ => => writing image sha256:c0bdd363c549c3b0731b8e8ce34153d081f239ee2b582355b7b3ffd5394c40bb 0.0s
+ => => naming to ghcr.io/apache/airflow/main/prod/python3.7-kubernetes:latest
-4. Run Kubernetes tests
+ NEXT STEP: You might now upload your k8s image by:
+
+ breeze k8s upload-k8s-image
+
+
+6. Upload the image to KinD cluster - this uploads your image to make it available for the KinD cluster.
+
+.. code-block:: bash
+
+ breeze k8s upload-k8s-image
+
+.. code-block:: text
+
+ K8S Virtualenv is initialized in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env
+ Good version of kind installed: 0.14.0 in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin
+ Good version of kubectl installed: 1.25.0 in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin
+ Good version of helm installed: 3.9.2 in /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin
+ Stable repo is already added
+ Uploading Airflow image ghcr.io/apache/airflow/main/prod/python3.7-kubernetes to cluster airflow-python-3.7-v1.24.2
+ Image: "ghcr.io/apache/airflow/main/prod/python3.7-kubernetes" with ID "sha256:fb6195f7c2c2ad97788a563a3fe9420bf3576c85575378d642cd7985aff97412" not yet present on node "airflow-python-3.7-v1.24.2-worker", loading...
+ Image: "ghcr.io/apache/airflow/main/prod/python3.7-kubernetes" with ID "sha256:fb6195f7c2c2ad97788a563a3fe9420bf3576c85575378d642cd7985aff97412" not yet present on node "airflow-python-3.7-v1.24.2-control-plane", loading...
+
+ NEXT STEP: You might now deploy airflow by:
+
+ breeze k8s deploy-airflow
+
+
+7. Deploy Airflow to the cluster - this will use Airflow Helm Chart to deploy Airflow to the cluster.
+
+.. code-block:: bash
+
+ breeze k8s deploy-airflow
+
+.. code-block:: text
+
+ Deploying Airflow for cluster airflow-python-3.7-v1.24.2
+ Deploying kind-airflow-python-3.7-v1.24.2 with airflow Helm Chart.
+ Copied chart sources to /private/var/folders/v3/gvj4_mw152q556w2rrh7m46w0000gn/T/chart_edu__kir/chart
+ Deploying Airflow from /private/var/folders/v3/gvj4_mw152q556w2rrh7m46w0000gn/T/chart_edu__kir/chart
+ NAME: airflow
+ LAST DEPLOYED: Tue Aug 30 22:57:54 2022
+ NAMESPACE: airflow
+ STATUS: deployed
+ REVISION: 1
+ TEST SUITE: None
+ NOTES:
+ Thank you for installing Apache Airflow 2.3.4!
+
+ Your release is named airflow.
+ You can now access your dashboard(s) by executing the following command(s) and visiting the corresponding port at localhost in your browser:
+
+ Airflow Webserver: kubectl port-forward svc/airflow-webserver 8080:8080 --namespace airflow
+ Default Webserver (Airflow UI) Login credentials:
+ username: admin
+ password: admin
+ Default Postgres connection credentials:
+ username: postgres
+ password: postgres
+ port: 5432
+
+ You can get Fernet Key value by running the following:
+
+ echo Fernet Key: $(kubectl get secret --namespace airflow airflow-fernet-key -o jsonpath="{.data.fernet-key}" | base64 --decode)
+
+ WARNING:
+ Kubernetes workers task logs may not persist unless you configure log persistence or remote logging!
+ Logging options can be found at: https://airflow.apache.org/docs/helm-chart/stable/manage-logs.html
+ (This warning can be ignored if logging is configured with environment variables or secrets backend)
+
+ ###########################################################
+ # WARNING: You should set a static webserver secret key #
+ ###########################################################
+
+ You are using a dynamically generated webserver secret key, which can lead to
+ unnecessary restarts of your Airflow components.
+
+ Information on how to set a static webserver secret key can be found here:
+ https://airflow.apache.org/docs/helm-chart/stable/production-guide.html#webserver-secret-key
+ Deployed kind-airflow-python-3.7-v1.24.2 with airflow Helm Chart.
+
+ Airflow for Python 3.7 and K8S version v1.24.2 has been successfully deployed.
+
+ The KinD cluster name: airflow-python-3.7-v1.24.2
+ The kubectl cluster name: kind-airflow-python-3.7-v1.24.2.
+
+
+ KinD Cluster API server URL: http://localhost:48366
+ Connecting to localhost:18150. Num try: 1
+ Established connection to webserver at http://localhost:18150/health and it is healthy.
+ Airflow Web server URL: http://localhost:18150 (admin/admin)
+
+ NEXT STEP: You might now run tests or interact with airflow via shell (kubectl, pytest etc.) or k9s commands:
+
+
+ breeze k8s tests
+
+ breeze k8s shell
+
+ breeze k8s k9s
+
+
+8. Run Kubernetes tests
Note that the tests are executed in production container not in the CI container.
There is no need for the tests to run inside the Airflow CI container image as they only
@@ -817,69 +1114,77 @@ communicate with the Kubernetes-run Airflow deployed via the production image.
Those Kubernetes tests require virtualenv to be created locally with airflow installed.
The virtualenv required will be created automatically when the scripts are run.
-4a) You can run all the tests
+8a) You can run all the tests
.. code-block:: bash
- ./breeze-legacy kind-cluster test
+ breeze k8s tests
+.. code-block:: text
-4b) You can enter an interactive shell to run tests one-by-one
+ Running tests with kind-airflow-python-3.7-v1.24.2 cluster.
+ Command to run: pytest kubernetes_tests
+ ========================================================================================= test session starts ==========================================================================================
+ platform darwin -- Python 3.9.9, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /Users/jarek/IdeaProjects/airflow/.build/.k8s-env/bin/python
+ cachedir: .pytest_cache
+ rootdir: /Users/jarek/IdeaProjects/airflow, configfile: pytest.ini
+ plugins: anyio-3.6.1, instafail-0.4.2, xdist-2.5.0, forked-1.4.0, timeouts-1.2.1, cov-3.0.0
+ setup timeout: 0.0s, execution timeout: 0.0s, teardown timeout: 0.0s
+ collected 55 items
-This prepares and enters the virtualenv in ``.build/.kubernetes_venv_`` folder:
+ kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag PASSED [ 1%]
+ kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag_with_scheduler_failure PASSED [ 3%]
+ kubernetes_tests/test_kubernetes_pod_operator.py::TestKubernetesPodOperatorSystem::test_already_checked_on_failure PASSED [ 5%]
+ kubernetes_tests/test_kubernetes_pod_operator.py::TestKubernetesPodOperatorSystem::test_already_checked_on_success ...
-.. code-block:: bash
-
- ./breeze-legacy kind-cluster shell
-
-Once you enter the environment, you receive this information:
+8b) You can enter an interactive shell to run tests one-by-one
+This enters the virtualenv in ``.build/.k8s-env`` folder:
.. code-block:: bash
- Activating the virtual environment for kubernetes testing
+ breeze k8s shell
- You can run kubernetes testing via 'pytest kubernetes_tests/....'
- You can add -s to see the output of your tests on screen
+Once you enter the environment, you receive this information:
- The webserver is available at http://localhost:8080/
+.. code-block:: text
- User/password: admin/admin
+ Entering interactive k8s shell.
- You are entering the virtualenv now. Type exit to exit back to the original shell
+ (kind-airflow-python-3.7-v1.24.2:KubernetesExecutor)>
In a separate terminal you can open the k9s CLI:
.. code-block:: bash
- ./breeze-legacy kind-cluster k9s
+ breeze k8s k9s
Use it to observe what's going on in your cluster.
-6. Debugging in IntelliJ/PyCharm
+9. Debugging in IntelliJ/PyCharm
It is very easy to run/debug Kubernetes tests with IntelliJ/PyCharm. Unlike the regular tests they are
in ``kubernetes_tests`` folder and if you followed the previous steps and entered the shell using
-``./breeze-legacy kind-cluster shell`` command, you can setup your IDE very easy to run (and debug) your
+``breeze k8s shell`` command, you can set up your IDE very easily to run (and debug) your
tests using the standard IntelliJ Run/Debug feature. You just need a few steps:
-a) Add the virtualenv as interpreter for the project:
+9a) Add the virtualenv as interpreter for the project:
.. image:: images/testing/kubernetes-virtualenv.png
:align: center
:alt: Kubernetes testing virtualenv
The virtualenv is created in your "Airflow" source directory in the
-``.build/.kubernetes_venv_`` folder and you
-have to find ``python`` binary and choose it when selecting interpreter.
+``.build/.k8s-env`` folder and you have to find ``python`` binary and choose
+it when selecting interpreter.
-b) Choose pytest as test runner:
+9b) Choose pytest as test runner:
.. image:: images/testing/pytest-runner.png
:align: center
:alt: Pytest runner
-c) Run/Debug tests using standard "Run/Debug" feature of IntelliJ
+9c) Run/Debug tests using standard "Run/Debug" feature of IntelliJ
.. image:: images/testing/run-test.png
:align: center
@@ -889,7 +1194,7 @@ c) Run/Debug tests using standard "Run/Debug" feature of IntelliJ
NOTE! The first time you run it, it will likely fail with
``kubernetes.config.config_exception.ConfigException``:
``Invalid kube-config file. Expected key current-context in kube-config``. You need to add KUBECONFIG
-environment variable copying it from the result of "./breeze-legacy kind-cluster test":
+environment variable copying it from the result of "breeze k8s tests":
.. code-block:: bash
@@ -897,7 +1202,6 @@ environment variable copying it from the result of "./breeze-legacy kind-cluster
/home/jarek/code/airflow/.build/.kube/config
-
.. image:: images/testing/kubeconfig-env.png
:align: center
:alt: Run/Debug tests
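As a rough sketch of the same fix outside the IDE, the snippet below builds an environment with ``KUBECONFIG`` pointing at the kube-config file and passes it to pytest. The helper names ``k8s_test_env`` and ``run_k8s_test`` are made up for illustration, and the ``.build/.kube/config`` location simply follows the example path above; use the path printed by ``breeze k8s tests`` on your machine.

```python
import os
import subprocess
from pathlib import Path


def k8s_test_env(airflow_sources: Path) -> dict:
    """Return a copy of the current environment with KUBECONFIG set.

    Hypothetical helper: the .build/.kube/config location follows the example
    path shown above and may differ on your machine.
    """
    env = dict(os.environ)
    env["KUBECONFIG"] = str(airflow_sources / ".build" / ".kube" / "config")
    return env


def run_k8s_test(airflow_sources: Path, test_id: str) -> int:
    """Run one Kubernetes test with pytest, using the environment above."""
    result = subprocess.run(
        ["pytest", test_id, "-s"],
        cwd=airflow_sources,
        env=k8s_test_env(airflow_sources),
        check=False,
    )
    return result.returncode
```

Setting the variable only for the child process keeps your regular shell configuration untouched.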
@@ -914,21 +1218,28 @@ print output generated test logs and print statements to the terminal immediatel
pytest kubernetes_tests/test_kubernetes_executor.py::TestKubernetesExecutor::test_integration_run_dag_with_scheduler_failure -s
-
You can modify the tests or KubernetesPodOperator and re-run them without re-deploying
Airflow to KinD cluster.
+10. Dumping logs
-Sometimes there are side effects from running tests. You can run ``redeploy_airflow.sh`` without
-recreating the whole cluster. This will delete the whole namespace, including the database data
-and start a new Airflow deployment in the cluster.
+Sometimes you want to see the logs of the cluster. This can be done with ``breeze k8s logs``.
.. code-block:: bash
- ./scripts/ci/redeploy_airflow.sh
+ breeze k8s logs
+
+11. Redeploying airflow
+
+Sometimes there are side effects from running tests. You can run ``breeze k8s deploy-airflow --upgrade``
+without recreating the whole cluster.
+
+.. code-block:: bash
-If needed you can also delete the cluster manually:
+ breeze k8s deploy-airflow --upgrade
+If needed you can also delete the cluster manually (within the virtualenv activated by
+``breeze k8s shell``):
.. code-block:: bash
@@ -941,20 +1252,30 @@ Kind has also useful commands to inspect your running cluster:
kind --help
-
-However, when you change Kubernetes executor implementation, you need to redeploy
-Airflow to the cluster.
+12. Stop KinD cluster when you are done
.. code-block:: bash
- ./breeze-legacy kind-cluster deploy
+ breeze k8s delete-cluster
+
+.. code-block:: text
+ Deleting KinD cluster airflow-python-3.7-v1.24.2!
+ Deleting cluster "airflow-python-3.7-v1.24.2" ...
+ KinD cluster airflow-python-3.7-v1.24.2 deleted!
-7. Stop KinD cluster when you are done
+
+Running complete k8s tests
+--------------------------
+
+You can also run complete k8s tests with:
.. code-block:: bash
- ./breeze-legacy kind-cluster stop
+ breeze k8s run-complete-tests
+
+This will create the cluster, build the images, deploy Airflow, run the tests and finally delete the cluster,
+all as a single command. This is how the tests are run in our CI. You can also run such complete tests in parallel.
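The individual steps described earlier can also be chained in a small script when you want more control than the single command gives. This is only a sketch: it uses the ``breeze k8s`` sub-commands shown in this document, the names ``K8S_STEPS``, ``CLEANUP_STEP`` and ``run_steps`` are made up for illustration, and it assumes the KinD cluster has already been created. It is not how ``run-complete-tests`` is implemented.

```python
import subprocess

# Sub-commands taken from the steps above, in the order they are described.
K8S_STEPS = [
    ["breeze", "k8s", "build-k8s-image"],
    ["breeze", "k8s", "upload-k8s-image"],
    ["breeze", "k8s", "deploy-airflow"],
    ["breeze", "k8s", "tests"],
]
CLEANUP_STEP = ["breeze", "k8s", "delete-cluster"]


def run_steps(steps, cleanup, runner=subprocess.call):
    """Run steps in order, stop at the first failure, always run cleanup.

    ``runner`` takes a command list and returns its exit code; it is a
    parameter so the flow can be exercised without breeze installed.
    """
    exit_code = 0
    try:
        for command in steps:
            exit_code = runner(command)
            if exit_code != 0:
                break
    finally:
        # The cluster is deleted even when an earlier step failed.
        runner(cleanup)
    return exit_code
```

Calling ``run_steps(K8S_STEPS, CLEANUP_STEP)`` would execute the real commands; passing a stub as ``runner`` lets you check the ordering first.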
Airflow System Tests
@@ -1074,8 +1395,7 @@ A simple example of a system test is available in:
``tests/providers/google/cloud/operators/test_compute_system.py``.
-It runs two DAGs defined in ``airflow.providers.google.cloud.example_dags.example_compute.py`` and
-``airflow.providers.google.cloud.example_dags.example_compute_igm.py``.
+It runs two DAGs defined in ``airflow.providers.google.cloud.example_dags.example_compute.py``.
Preparing provider packages for System Tests for Airflow 1.10.* series
----------------------------------------------------------------------
@@ -1086,7 +1406,7 @@ example, the below command will build google, postgres and mysql wheel packages:
.. code-block:: bash
- breeze prepare-provider-packages google postgres mysql
+ breeze release-management prepare-provider-packages google postgres mysql
Those packages will be prepared in ./dist folder. This folder is mapped to /dist folder
when you enter Breeze, so it is easy to automate installing those packages for testing.
diff --git a/airflow/__init__.py b/airflow/__init__.py
index cbbb03dd1be29..19624252a3875 100644
--- a/airflow/__init__.py
+++ b/airflow/__init__.py
@@ -15,8 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-
"""
Authentication is implemented using flask_login and different environments can
implement their own login mechanisms by providing an `airflow_login` module
@@ -25,28 +23,38 @@
isort:skip_file
"""
-
+from __future__ import annotations
# flake8: noqa: F401
+import os
import sys
-from typing import Callable, Optional
+from typing import Callable
-from airflow import settings
-from airflow import version
+if os.environ.get("_AIRFLOW_PATCH_GEVENT"):
+ # If you are using gevent and start the airflow webserver, you might want to run gevent monkeypatching
+ # as one of the first things when Airflow is started. This allows gevent to patch networking and other
+ # system libraries to make them gevent-compatible before anything else patches them (for example boto)
+ from gevent.monkey import patch_all
-__version__ = version.version
+ patch_all()
-__all__ = ['__version__', 'login', 'DAG', 'PY36', 'PY37', 'PY38', 'PY39', 'PY310', 'XComArg']
+from airflow import settings
+
+__all__ = ["__version__", "login", "DAG", "PY36", "PY37", "PY38", "PY39", "PY310", "XComArg"]
# Make `airflow` a namespace package, supporting installing
# airflow.providers.* in different locations (i.e. one in site, and one in user
# lib.)
__path__ = __import__("pkgutil").extend_path(__path__, __name__) # type: ignore
-settings.initialize()
-login: Optional[Callable] = None
+# Perform side-effects unless someone has explicitly opted out before import
+# WARNING: DO NOT USE THIS UNLESS YOU REALLY KNOW WHAT YOU'RE DOING.
+if not os.environ.get("_AIRFLOW__AS_LIBRARY", None):
+ settings.initialize()
+
+login: Callable | None = None
PY36 = sys.version_info >= (3, 6)
PY37 = sys.version_info >= (3, 7)
@@ -54,28 +62,31 @@
PY39 = sys.version_info >= (3, 9)
PY310 = sys.version_info >= (3, 10)
-# Things to lazy import in form 'name': 'path.to.module'
-__lazy_imports = {
- 'DAG': 'airflow.models.dag',
- 'XComArg': 'airflow.models.xcom_arg',
- 'AirflowException': 'airflow.exceptions',
+# Things to lazy import in form {local_name: ('target_module', 'target_name')}
+__lazy_imports: dict[str, tuple[str, str]] = {
+ "DAG": (".models.dag", "DAG"),
+ "Dataset": (".datasets", "Dataset"),
+ "XComArg": (".models.xcom_arg", "XComArg"),
+ "AirflowException": (".exceptions", "AirflowException"),
+ "version": (".version", ""),
+ "__version__": (".version", "version"),
}
-def __getattr__(name):
+def __getattr__(name: str):
# PEP-562: Lazy loaded attributes on python modules
- path = __lazy_imports.get(name)
- if not path:
+ module_path, attr_name = __lazy_imports.get(name, ("", ""))
+ if not module_path:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
- import operator
+ import importlib
- # Strip off the "airflow." prefix because of how `__import__` works (it always returns the top level
- # module)
- without_prefix = path.split('.', 1)[-1]
+ mod = importlib.import_module(module_path, __name__)
+ if attr_name:
+ val = getattr(mod, attr_name)
+ else:
+ val = mod
- getter = operator.attrgetter(f'{without_prefix}.{name}')
- val = getter(__import__(path))
# Store for next time
globals()[name] = val
return val
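The lazy-loading pattern in this hunk can be tried in isolation. The sketch below is not Airflow code: it builds a throwaway module object and gives it a PEP 562 ``__getattr__`` driven by a mapping of the same ``{local_name: (target_module, target_name)}`` shape. The names ``LAZY_IMPORTS`` and ``make_lazy_module`` and the use of the standard ``json`` module as the import target are assumptions made purely for illustration.

```python
import importlib
import types

# Hypothetical mapping in the same shape as above; an empty target name
# means "return the module itself".
LAZY_IMPORTS = {
    "dumps": ("json", "dumps"),
    "json_module": ("json", ""),
}


def make_lazy_module(name, lazy_imports):
    """Create a module whose listed attributes are imported on first access."""
    module = types.ModuleType(name)

    def __getattr__(attr):
        module_path, attr_name = lazy_imports.get(attr, ("", ""))
        if not module_path:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        target = importlib.import_module(module_path)
        value = getattr(target, attr_name) if attr_name else target
        # Store for next time so __getattr__ is not consulted again.
        module.__dict__[attr] = value
        return value

    # PEP 562: a function named __getattr__ in the module namespace is
    # called when normal attribute lookup fails.
    module.__getattr__ = __getattr__
    return module
```

After the first access the value lives in the module namespace, so later lookups take the normal fast path.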
@@ -99,7 +110,7 @@ def __getattr__(name):
# into knowing the types of these symbols, and what
# they contain.
STATICA_HACK = True
-globals()['kcah_acitats'[::-1].upper()] = False
+globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
from airflow.models.dag import DAG
from airflow.models.xcom_arg import XComArg
diff --git a/airflow/__main__.py b/airflow/__main__.py
index 334126b2d930b..6114534e1c887 100644
--- a/airflow/__main__.py
+++ b/airflow/__main__.py
@@ -17,8 +17,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Main executable module"""
+from __future__ import annotations
+
import os
import argcomplete
@@ -29,14 +30,14 @@
def main():
"""Main executable function"""
- if conf.get("core", "security") == 'kerberos':
- os.environ['KRB5CCNAME'] = conf.get('kerberos', 'ccache')
- os.environ['KRB5_KTNAME'] = conf.get('kerberos', 'keytab')
+ if conf.get("core", "security") == "kerberos":
+ os.environ["KRB5CCNAME"] = conf.get("kerberos", "ccache")
+ os.environ["KRB5_KTNAME"] = conf.get("kerberos", "keytab")
parser = cli_parser.get_parser()
argcomplete.autocomplete(parser)
args = parser.parse_args()
args.func(args)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/airflow/api/__init__.py b/airflow/api/__init__.py
index da07429869877..656009b0dd69d 100644
--- a/airflow/api/__init__.py
+++ b/airflow/api/__init__.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Authentication backend"""
+"""Authentication backend."""
+from __future__ import annotations
+
import logging
from importlib import import_module
@@ -26,8 +28,8 @@
def load_auth():
- """Loads authentication backends"""
- auth_backends = 'airflow.api.auth.backend.default'
+ """Load authentication backends."""
+ auth_backends = "airflow.api.auth.backend.default"
try:
auth_backends = conf.get("api", "auth_backends")
except AirflowConfigException:
diff --git a/airflow/api/auth/backend/basic_auth.py b/airflow/api/auth/backend/basic_auth.py
index 397a722a98cf2..3f802fde636a5 100644
--- a/airflow/api/auth/backend/basic_auth.py
+++ b/airflow/api/auth/backend/basic_auth.py
@@ -14,33 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Basic authentication backend"""
+"""Basic authentication backend."""
+from __future__ import annotations
+
from functools import wraps
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, TypeVar, cast
-from flask import Response, current_app, request
+from flask import Response, request
from flask_appbuilder.const import AUTH_LDAP
from flask_login import login_user
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.www.fab_security.sqla.models import User
-CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None
+CLIENT_AUTH: tuple[str, str] | Any | None = None
def init_app(_):
- """Initializes authentication backend"""
+ """Initialize authentication backend."""
T = TypeVar("T", bound=Callable)
-def auth_current_user() -> Optional[User]:
- """Authenticate and set current user if Authorization header exists"""
+def auth_current_user() -> User | None:
+ """Authenticate and set current user if Authorization header exists."""
auth = request.authorization
if auth is None or not auth.username or not auth.password:
return None
- ab_security_manager = current_app.appbuilder.sm
+ ab_security_manager = get_airflow_app().appbuilder.sm
user = None
if ab_security_manager.auth_type == AUTH_LDAP:
user = ab_security_manager.auth_user_ldap(auth.username, auth.password)
@@ -52,7 +55,7 @@ def auth_current_user() -> Optional[User]:
def requires_authentication(function: T):
- """Decorator for functions that require authentication"""
+ """Decorate functions that require authentication."""
@wraps(function)
def decorated(*args, **kwargs):
diff --git a/airflow/api/auth/backend/default.py b/airflow/api/auth/backend/default.py
index 6b0a1a6c67907..b1d8f9bcf2674 100644
--- a/airflow/api/auth/backend/default.py
+++ b/airflow/api/auth/backend/default.py
@@ -15,22 +15,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Default authentication backend - everything is allowed"""
+"""Default authentication backend - everything is allowed."""
+from __future__ import annotations
+
from functools import wraps
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, TypeVar, cast
-CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None
+CLIENT_AUTH: tuple[str, str] | Any | None = None
def init_app(_):
- """Initializes authentication backend"""
+ """Initialize authentication backend."""
T = TypeVar("T", bound=Callable)
def requires_authentication(function: T):
- """Decorator for functions that require authentication"""
+ """Decorate functions that require authentication."""
@wraps(function)
def decorated(*args, **kwargs):
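The auth backends touched in this patch all expose the same ``requires_authentication`` decorator shape. The sketch below shows that shape without Flask. The names ``allow_all`` and ``deny_all`` and the ``(body, status)`` tuple are simplifications chosen for illustration, not the actual backend code.

```python
from functools import wraps
from typing import Callable, TypeVar, cast

T = TypeVar("T", bound=Callable)


def allow_all(function: T) -> T:
    """Decorator that lets every call through unchanged."""

    @wraps(function)
    def decorated(*args, **kwargs):
        return function(*args, **kwargs)

    return cast(T, decorated)


def deny_all(function: T) -> T:
    """Decorator that never calls the wrapped function."""

    @wraps(function)
    def decorated(*args, **kwargs):
        # Stand-in for an HTTP response object: (body, status code).
        return ("Forbidden", 403)

    return cast(T, decorated)
```

The ``cast(T, decorated)`` call only informs type checkers that the wrapper keeps the signature of the wrapped function; it does nothing at runtime.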
diff --git a/airflow/api/auth/backend/deny_all.py b/airflow/api/auth/backend/deny_all.py
index 614e263684ad8..29b23f8d73b5a 100644
--- a/airflow/api/auth/backend/deny_all.py
+++ b/airflow/api/auth/backend/deny_all.py
@@ -15,24 +15,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Authentication backend that denies all requests"""
+"""Authentication backend that denies all requests."""
+from __future__ import annotations
+
from functools import wraps
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, TypeVar, cast
from flask import Response
-CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None
+CLIENT_AUTH: tuple[str, str] | Any | None = None
def init_app(_):
- """Initializes authentication"""
+ """Initialize authentication."""
T = TypeVar("T", bound=Callable)
def requires_authentication(function: T):
- """Decorator for functions that require authentication"""
+ """Decorate functions that require authentication."""
@wraps(function)
def decorated(*args, **kwargs):
diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py
index fb76e8a1aa0fa..abb951be968ed 100644
--- a/airflow/api/auth/backend/kerberos_auth.py
+++ b/airflow/api/auth/backend/kerberos_auth.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
#
# Copyright (c) 2013, Michael Komitee
@@ -43,23 +44,23 @@
import logging
import os
from functools import wraps
-from socket import getfqdn
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, TypeVar, cast
import kerberos
from flask import Response, _request_ctx_stack as stack, g, make_response, request # type: ignore
from requests_kerberos import HTTPKerberosAuth
from airflow.configuration import conf
+from airflow.utils.net import getfqdn
log = logging.getLogger(__name__)
-CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = HTTPKerberosAuth(service='airflow')
+CLIENT_AUTH: tuple[str, str] | Any | None = HTTPKerberosAuth(service="airflow")
class KerberosService:
- """Class to keep information about the Kerberos Service initialized"""
+ """Class to keep information about the Kerberos Service initialized."""
def __init__(self):
self.service_name = None
@@ -70,18 +71,18 @@ def __init__(self):
def init_app(app):
- """Initializes application with kerberos"""
- hostname = app.config.get('SERVER_NAME')
+ """Initialize application with kerberos."""
+ hostname = app.config.get("SERVER_NAME")
if not hostname:
hostname = getfqdn()
log.info("Kerberos: hostname %s", hostname)
- service = 'airflow'
+ service = "airflow"
_KERBEROS_SERVICE.service_name = f"{service}@{hostname}"
- if 'KRB5_KTNAME' not in os.environ:
- os.environ['KRB5_KTNAME'] = conf.get('kerberos', 'keytab')
+ if "KRB5_KTNAME" not in os.environ:
+ os.environ["KRB5_KTNAME"] = conf.get("kerberos", "keytab")
try:
log.info("Kerberos init: %s %s", service, hostname)
@@ -94,7 +95,7 @@ def init_app(app):
def _unauthorized():
"""
- Indicate that authorization is required
+ Indicate that authorization is required.
:return:
"""
return Response("Unauthorized", 401, {"WWW-Authenticate": "Negotiate"})
@@ -130,21 +131,21 @@ def _gssapi_authenticate(token):
def requires_authentication(function: T):
- """Decorator for functions that require authentication with Kerberos"""
+ """Decorate functions that require authentication with Kerberos."""
@wraps(function)
def decorated(*args, **kwargs):
header = request.headers.get("Authorization")
if header:
ctx = stack.top
- token = ''.join(header.split()[1:])
+ token = "".join(header.split()[1:])
return_code = _gssapi_authenticate(token)
if return_code == kerberos.AUTH_GSS_COMPLETE:
g.user = ctx.kerberos_user
response = function(*args, **kwargs)
response = make_response(response)
if ctx.kerberos_token is not None:
- response.headers['WWW-Authenticate'] = ' '.join(['negotiate', ctx.kerberos_token])
+ response.headers["WWW-Authenticate"] = " ".join(["negotiate", ctx.kerberos_token])
return response
if return_code != kerberos.AUTH_GSS_CONTINUE:
diff --git a/airflow/api/auth/backend/session.py b/airflow/api/auth/backend/session.py
index 3f345c8b38240..8cc75ccad024e 100644
--- a/airflow/api/auth/backend/session.py
+++ b/airflow/api/auth/backend/session.py
@@ -14,24 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Session authentication backend"""
+"""Session authentication backend."""
+from __future__ import annotations
+
from functools import wraps
-from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast
+from typing import Any, Callable, TypeVar, cast
from flask import Response, g
-CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None
+CLIENT_AUTH: tuple[str, str] | Any | None = None
def init_app(_):
- """Initializes authentication backend"""
+ """Initialize authentication backend."""
T = TypeVar("T", bound=Callable)
def requires_authentication(function: T):
- """Decorator for functions that require authentication"""
+ """Decorate functions that require authentication."""
@wraps(function)
def decorated(*args, **kwargs):
diff --git a/airflow/api/client/__init__.py b/airflow/api/client/__init__.py
index 49224b5336a32..35608032abcb6 100644
--- a/airflow/api/client/__init__.py
+++ b/airflow/api/client/__init__.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""API Client that allows interacting with Airflow API"""
+"""API Client that allows interacting with Airflow API."""
+from __future__ import annotations
+
from importlib import import_module
-from typing import Any
from airflow import api
from airflow.api.client.api_client import Client
@@ -25,17 +26,17 @@
def get_current_api_client() -> Client:
- """Return current API Client based on current Airflow configuration"""
- api_module = import_module(conf.get_mandatory_value('cli', 'api_client')) # type: Any
+ """Return current API Client based on current Airflow configuration."""
+ api_module = import_module(conf.get_mandatory_value("cli", "api_client"))
auth_backends = api.load_auth()
session = None
for backend in auth_backends:
- session_factory = getattr(backend, 'create_client_session', None)
+ session_factory = getattr(backend, "create_client_session", None)
if session_factory:
session = session_factory()
api_client = api_module.Client(
- api_base_url=conf.get('cli', 'endpoint_url'),
- auth=getattr(backend, 'CLIENT_AUTH', None),
+ api_base_url=conf.get("cli", "endpoint_url"),
+ auth=getattr(backend, "CLIENT_AUTH", None),
session=session,
)
return api_client
diff --git a/airflow/api/client/api_client.py b/airflow/api/client/api_client.py
index c116d3b75ebb7..1334771b8de65 100644
--- a/airflow/api/client/api_client.py
+++ b/airflow/api/client/api_client.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Client for all the API clients."""
+from __future__ import annotations
+
import httpx
@@ -75,7 +77,7 @@ def delete_pool(self, name):
def get_lineage(self, dag_id: str, execution_date: str):
"""
- Return the lineage information for the dag on this execution date
+ Return the lineage information for the dag on this execution date.
:param dag_id:
:param execution_date:
:return:
diff --git a/airflow/api/client/json_client.py b/airflow/api/client/json_client.py
index d87a3aeef65b3..c6995ceb2a5cf 100644
--- a/airflow/api/client/json_client.py
+++ b/airflow/api/client/json_client.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""JSON API Client"""
+"""JSON API Client."""
+from __future__ import annotations
from urllib.parse import urljoin
@@ -25,12 +26,12 @@
class Client(api_client.Client):
"""Json API client implementation."""
- def _request(self, url, method='GET', json=None):
+ def _request(self, url, method="GET", json=None):
params = {
- 'url': url,
+ "url": url,
}
if json is not None:
- params['json'] = json
+ params["json"] = json
resp = getattr(self._session, method.lower())(**params)
if resp.is_error:
# It is justified here because there might be many resp types.
@@ -38,64 +39,64 @@ def _request(self, url, method='GET', json=None):
data = resp.json()
except Exception:
data = {}
- raise OSError(data.get('error', 'Server error'))
+ raise OSError(data.get("error", "Server error"))
return resp.json()
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
- endpoint = f'/api/experimental/dags/{dag_id}/dag_runs'
+ endpoint = f"/api/experimental/dags/{dag_id}/dag_runs"
url = urljoin(self._api_base_url, endpoint)
data = self._request(
url,
- method='POST',
+ method="POST",
json={
"run_id": run_id,
"conf": conf,
"execution_date": execution_date,
},
)
- return data['message']
+ return data["message"]
def delete_dag(self, dag_id):
- endpoint = f'/api/experimental/dags/{dag_id}/delete_dag'
+ endpoint = f"/api/experimental/dags/{dag_id}/delete_dag"
url = urljoin(self._api_base_url, endpoint)
- data = self._request(url, method='DELETE')
- return data['message']
+ data = self._request(url, method="DELETE")
+ return data["message"]
def get_pool(self, name):
- endpoint = f'/api/experimental/pools/{name}'
+ endpoint = f"/api/experimental/pools/{name}"
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url)
- return pool['pool'], pool['slots'], pool['description']
+ return pool["pool"], pool["slots"], pool["description"]
def get_pools(self):
- endpoint = '/api/experimental/pools'
+ endpoint = "/api/experimental/pools"
url = urljoin(self._api_base_url, endpoint)
pools = self._request(url)
- return [(p['pool'], p['slots'], p['description']) for p in pools]
+ return [(p["pool"], p["slots"], p["description"]) for p in pools]
def create_pool(self, name, slots, description):
- endpoint = '/api/experimental/pools'
+ endpoint = "/api/experimental/pools"
url = urljoin(self._api_base_url, endpoint)
pool = self._request(
url,
- method='POST',
+ method="POST",
json={
- 'name': name,
- 'slots': slots,
- 'description': description,
+ "name": name,
+ "slots": slots,
+ "description": description,
},
)
- return pool['pool'], pool['slots'], pool['description']
+ return pool["pool"], pool["slots"], pool["description"]
def delete_pool(self, name):
- endpoint = f'/api/experimental/pools/{name}'
+ endpoint = f"/api/experimental/pools/{name}"
url = urljoin(self._api_base_url, endpoint)
- pool = self._request(url, method='DELETE')
- return pool['pool'], pool['slots'], pool['description']
+ pool = self._request(url, method="DELETE")
+ return pool["pool"], pool["slots"], pool["description"]
def get_lineage(self, dag_id: str, execution_date: str):
endpoint = f"/api/experimental/lineage/{dag_id}/{execution_date}"
url = urljoin(self._api_base_url, endpoint)
- data = self._request(url, method='GET')
- return data['message']
+ data = self._request(url, method="GET")
+ return data["message"]
diff --git a/airflow/api/client/local_client.py b/airflow/api/client/local_client.py
index c0050672a8e47..3b6e61241eaa1 100644
--- a/airflow/api/client/local_client.py
+++ b/airflow/api/client/local_client.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Local client API"""
+"""Local client API."""
+from __future__ import annotations
from airflow.api.client import api_client
from airflow.api.common import delete_dag, trigger_dag
diff --git a/airflow/api/common/delete_dag.py b/airflow/api/common/delete_dag.py
index 5e0afa81cb5c9..39f1461fccc9e 100644
--- a/airflow/api/common/delete_dag.py
+++ b/airflow/api/common/delete_dag.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Delete DAGs APIs."""
+from __future__ import annotations
+
import logging
from sqlalchemy import and_, or_
@@ -34,6 +36,8 @@
@provide_session
def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> int:
"""
+ Delete a DAG by a dag_id.
+
:param dag_id: the dag_id of the DAG to delete
:param keep_records_in_log: whether to keep records of the given dag_id
in the Log table in the backend database (for reasons like auditing).
@@ -72,12 +76,12 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i
for model in get_sqla_model_classes():
if hasattr(model, "dag_id"):
- if keep_records_in_log and model.__name__ == 'Log':
+ if keep_records_in_log and model.__name__ == "Log":
continue
count += (
session.query(model)
.filter(model.dag_id.in_(dags_to_delete))
- .delete(synchronize_session='fetch')
+ .delete(synchronize_session="fetch")
)
if dag.is_subdag:
parent_dag_id, task_id = dag_id.rsplit(".", 1)
@@ -89,7 +93,7 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i
# Delete entries in Import Errors table for a deleted DAG
# This handles the case when the dag_id is changed in the file
session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
- synchronize_session='fetch'
+ synchronize_session="fetch"
)
return count
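Note on the subdag branch above: `dag_id.rsplit(".", 1)` splits only at the last dot, so a parent id that itself contains dots stays intact. A minimal sketch of that behaviour; the helper name `split_subdag_id` and the sample id are hypothetical and not part of Airflow:

```python
from typing import Tuple


def split_subdag_id(dag_id: str) -> Tuple[str, str]:
    # Split at the last "." only: everything before it is treated as the
    # parent dag_id, the remainder as the task_id of the subdag operator.
    parent_dag_id, task_id = dag_id.rsplit(".", 1)
    return parent_dag_id, task_id


print(split_subdag_id("parent.section.child"))
```

An id without any dot would raise `ValueError` on unpacking, which is presumably why the call in the diff is guarded by `dag.is_subdag`.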
diff --git a/airflow/api/common/experimental/__init__.py b/airflow/api/common/experimental/__init__.py
index b161e04346358..35a4da3f1735a 100644
--- a/airflow/api/common/experimental/__init__.py
+++ b/airflow/api/common/experimental/__init__.py
@@ -16,15 +16,16 @@
# specific language governing permissions and limitations
# under the License.
"""Experimental APIs."""
+from __future__ import annotations
+
from datetime import datetime
-from typing import Optional
from airflow.exceptions import DagNotFound, DagRunNotFound, TaskNotFound
from airflow.models import DagBag, DagModel, DagRun
-def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
- """Checks that DAG exists and in case it is specified that Task exist"""
+def check_and_get_dag(dag_id: str, task_id: str | None = None) -> DagModel:
+ """Check that the DAG exists and, if a task_id is specified, that the Task exists."""
dag_model = DagModel.get_current(dag_id)
if dag_model is None:
raise DagNotFound(f"Dag id {dag_id} not found in DagModel")
@@ -35,15 +36,15 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
error_message = f"Dag id {dag_id} not found"
raise DagNotFound(error_message)
if task_id and not dag.has_task(task_id):
- error_message = f'Task {task_id} not found in dag {dag_id}'
+ error_message = f"Task {task_id} not found in dag {dag_id}"
raise TaskNotFound(error_message)
return dag
def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun:
- """Get DagRun object and check that it exists"""
+ """Get DagRun object and check that it exists."""
dagrun = dag.get_dagrun(execution_date=execution_date)
if not dagrun:
- error_message = f'Dag Run for date {execution_date} not found in dag {dag.dag_id}'
+ error_message = f"Dag Run for date {execution_date} not found in dag {dag.dag_id}"
raise DagRunNotFound(error_message)
return dagrun
diff --git a/airflow/api/common/experimental/delete_dag.py b/airflow/api/common/experimental/delete_dag.py
index 36bf7dd8c46a7..821b80aa9af57 100644
--- a/airflow/api/common/experimental/delete_dag.py
+++ b/airflow/api/common/experimental/delete_dag.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import warnings
from airflow.api.common.delete_dag import * # noqa
diff --git a/airflow/api/common/experimental/get_code.py b/airflow/api/common/experimental/get_code.py
index d4232b1d0903b..9e2e8c08a4b56 100644
--- a/airflow/api/common/experimental/get_code.py
+++ b/airflow/api/common/experimental/get_code.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Get code APIs."""
+from __future__ import annotations
+
from deprecated import deprecated
from airflow.api.common.experimental import check_and_get_dag
diff --git a/airflow/api/common/experimental/get_dag_run_state.py b/airflow/api/common/experimental/get_dag_run_state.py
index 7201186ea9331..cdf044a1569b1 100644
--- a/airflow/api/common/experimental/get_dag_run_state.py
+++ b/airflow/api/common/experimental/get_dag_run_state.py
@@ -16,8 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""DAG run APIs."""
+from __future__ import annotations
+
from datetime import datetime
-from typing import Dict
from deprecated import deprecated
@@ -25,7 +26,7 @@
@deprecated(reason="Use DagRun().get_state() instead", version="2.2.4")
-def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]:
+def get_dag_run_state(dag_id: str, execution_date: datetime) -> dict[str, str]:
"""Return the Dag Run state identified by the given dag_id and execution_date.
:param dag_id: DAG id
@@ -36,4 +37,4 @@ def get_dag_run_state(dag_id: str, execution_date: datetime) -> Dict[str, str]:
dagrun = check_and_get_dagrun(dag, execution_date)
- return {'state': dagrun.get_state()}
+ return {"state": dagrun.get_state()}
diff --git a/airflow/api/common/experimental/get_dag_runs.py b/airflow/api/common/experimental/get_dag_runs.py
index 2064d6eb51267..2761bb45fd2cd 100644
--- a/airflow/api/common/experimental/get_dag_runs.py
+++ b/airflow/api/common/experimental/get_dag_runs.py
@@ -16,7 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""DAG runs APIs."""
-from typing import Any, Dict, List, Optional
+from __future__ import annotations
+
+from typing import Any
from flask import url_for
@@ -25,9 +27,9 @@
from airflow.utils.state import DagRunState
-def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any]]:
+def get_dag_runs(dag_id: str, state: str | None = None) -> list[dict[str, Any]]:
"""
- Returns a list of Dag Runs for a specific DAG ID.
+ Return a list of Dag Runs for a specific DAG ID.
:param dag_id: String identifier of a DAG
:param state: queued|running|success...
@@ -41,13 +43,13 @@ def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any
for run in DagRun.find(dag_id=dag_id, state=state):
dag_runs.append(
{
- 'id': run.id,
- 'run_id': run.run_id,
- 'state': run.state,
- 'dag_id': run.dag_id,
- 'execution_date': run.execution_date.isoformat(),
- 'start_date': ((run.start_date or '') and run.start_date.isoformat()),
- 'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, execution_date=run.execution_date),
+ "id": run.id,
+ "run_id": run.run_id,
+ "state": run.state,
+ "dag_id": run.dag_id,
+ "execution_date": run.execution_date.isoformat(),
+ "start_date": ((run.start_date or "") and run.start_date.isoformat()),
+ "dag_run_url": url_for("Airflow.graph", dag_id=run.dag_id, execution_date=run.execution_date),
}
)
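Note on the `start_date` entry above: the `(x or "") and x.isoformat()` idiom short-circuits to an empty string when the run has no start date, and otherwise returns the ISO 8601 string. A minimal sketch, assuming a plain `datetime` in place of the `DagRun` attribute; the function name `start_date_field` is hypothetical:

```python
from datetime import datetime
from typing import Optional


def start_date_field(start_date: Optional[datetime]) -> str:
    # Same short-circuit idiom as in the diff: a missing start_date yields "",
    # otherwise `and` evaluates and returns the ISO-formatted string.
    return (start_date or "") and start_date.isoformat()


print(repr(start_date_field(None)))
print(start_date_field(datetime(2022, 9, 1, 12, 30)))
```

A conditional expression (`x.isoformat() if x else ""`) would be an equivalent and arguably more readable spelling; the diff only changes the quoting, not the idiom.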
diff --git a/airflow/api/common/experimental/get_lineage.py b/airflow/api/common/experimental/get_lineage.py
index 461590b70248f..73bc9dd862ef3 100644
--- a/airflow/api/common/experimental/get_lineage.py
+++ b/airflow/api/common/experimental/get_lineage.py
@@ -15,10 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Lineage apis"""
+"""Lineage APIs."""
+from __future__ import annotations
+
import collections
import datetime
-from typing import Any, Dict
+from typing import Any
from sqlalchemy.orm import Session
@@ -31,15 +33,15 @@
@provide_session
def get_lineage(
dag_id: str, execution_date: datetime.datetime, *, session: Session = NEW_SESSION
-) -> Dict[str, Dict[str, Any]]:
- """Gets the lineage information for dag specified."""
+) -> dict[str, dict[str, Any]]:
+ """Get lineage information for dag specified."""
dag = check_and_get_dag(dag_id)
dagrun = check_and_get_dagrun(dag, execution_date)
inlets = XCom.get_many(dag_ids=dag_id, run_id=dagrun.run_id, key=PIPELINE_INLETS, session=session)
outlets = XCom.get_many(dag_ids=dag_id, run_id=dagrun.run_id, key=PIPELINE_OUTLETS, session=session)
- lineage: Dict[str, Dict[str, Any]] = collections.defaultdict(dict)
+ lineage: dict[str, dict[str, Any]] = collections.defaultdict(dict)
for meta in inlets:
lineage[meta.task_id]["inlets"] = meta.value
for meta in outlets:
diff --git a/airflow/api/common/experimental/get_task.py b/airflow/api/common/experimental/get_task.py
index 4589cc6ce4d42..34e0fac37983f 100644
--- a/airflow/api/common/experimental/get_task.py
+++ b/airflow/api/common/experimental/get_task.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Task APIs.."""
+"""Task APIs."""
+from __future__ import annotations
+
from deprecated import deprecated
from airflow.api.common.experimental import check_and_get_dag
diff --git a/airflow/api/common/experimental/get_task_instance.py b/airflow/api/common/experimental/get_task_instance.py
index 7361efdc4c796..cc8c734338284 100644
--- a/airflow/api/common/experimental/get_task_instance.py
+++ b/airflow/api/common/experimental/get_task_instance.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Task Instance APIs."""
+"""Task instance APIs."""
+from __future__ import annotations
+
from datetime import datetime
from deprecated import deprecated
@@ -34,7 +36,7 @@ def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> Ta
# Get task instance object and check that it exists
task_instance = dagrun.get_task_instance(task_id)
if not task_instance:
- error_message = f'Task {task_id} instance for date {execution_date} not found'
+ error_message = f"Task {task_id} instance for date {execution_date} not found"
raise TaskInstanceNotFound(error_message)
return task_instance
diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py
index 81cff3e30dad8..303c9f98ee59e 100644
--- a/airflow/api/common/experimental/mark_tasks.py
+++ b/airflow/api/common/experimental/mark_tasks.py
@@ -15,6 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Task Instance APIs."""
+from __future__ import annotations
+
import warnings
from airflow.api.common.mark_tasks import ( # noqa
diff --git a/airflow/api/common/experimental/pool.py b/airflow/api/common/experimental/pool.py
index 12ebc19dfd41a..a37c3e4086042 100644
--- a/airflow/api/common/experimental/pool.py
+++ b/airflow/api/common/experimental/pool.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Pool APIs."""
+from __future__ import annotations
+
from deprecated import deprecated
from airflow.exceptions import AirflowBadRequest, PoolNotFound
diff --git a/airflow/api/common/experimental/trigger_dag.py b/airflow/api/common/experimental/trigger_dag.py
index d52631281f534..123b09cb1c02d 100644
--- a/airflow/api/common/experimental/trigger_dag.py
+++ b/airflow/api/common/experimental/trigger_dag.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import warnings
diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index 83bdb2081f0b8..8e25e0f4fd660 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -16,9 +16,10 @@
# specific language governing permissions and limitations
# under the License.
"""Marks tasks APIs."""
+from __future__ import annotations
from datetime import datetime
-from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator, NamedTuple
from sqlalchemy import or_
from sqlalchemy.orm import lazyload
@@ -38,7 +39,7 @@
class _DagRunInfo(NamedTuple):
logical_date: datetime
- data_interval: Tuple[datetime, datetime]
+ data_interval: tuple[datetime, datetime]
def _create_dagruns(
@@ -78,9 +79,9 @@ def _create_dagruns(
@provide_session
def set_state(
*,
- tasks: Collection[Union[Operator, Tuple[Operator, int]]],
- run_id: Optional[str] = None,
- execution_date: Optional[datetime] = None,
+ tasks: Collection[Operator | tuple[Operator, int]],
+ run_id: str | None = None,
+ execution_date: datetime | None = None,
upstream: bool = False,
downstream: bool = False,
future: bool = False,
@@ -88,13 +89,14 @@ def set_state(
state: TaskInstanceState = TaskInstanceState.SUCCESS,
commit: bool = False,
session: SASession = NEW_SESSION,
-) -> List[TaskInstance]:
+) -> list[TaskInstance]:
"""
- Set the state of a task instance and if needed its relatives. Can set state
- for future tasks (calculated from run_id) and retroactively
+ Set the state of a task instance and if needed its relatives.
+
+ Can set state for future tasks (calculated from run_id) and retroactively
for past tasks. Will verify integrity of past dag runs in order to create
tasks that did not exist. It will not create dag runs that are missing
- on the schedule (but it will as for subdag dag runs if needed).
+ on the schedule (but it will create subdag dag runs if needed).
:param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
``task.dag`` needs to be set
@@ -152,6 +154,10 @@ def set_state(
qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
tis_altered += qry_sub_dag.with_for_update().all()
for task_instance in tis_altered:
+ # The try_number was decremented when setting to up_for_reschedule and deferred.
+ # Increment it back when changing the state again
+ if task_instance.state in [State.DEFERRED, State.UP_FOR_RESCHEDULE]:
+ task_instance._try_number += 1
task_instance.set_state(state, session=session)
session.flush()
else:
@@ -163,12 +169,12 @@ def set_state(
def all_subdag_tasks_query(
- sub_dag_run_ids: List[str],
+ sub_dag_run_ids: list[str],
session: SASession,
state: TaskInstanceState,
confirmed_dates: Iterable[datetime],
):
- """Get *all* tasks of the sub dags"""
+ """Get *all* tasks of the sub dags."""
qry_sub_dag = (
session.query(TaskInstance)
.filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
@@ -181,10 +187,10 @@ def get_all_dag_task_query(
dag: DAG,
session: SASession,
state: TaskInstanceState,
- task_ids: List[Union[str, Tuple[str, int]]],
+ task_ids: list[str | tuple[str, int]],
run_ids: Iterable[str],
):
- """Get all tasks of the main dag that will be affected by a state change"""
+ """Get all tasks of the main dag that will be affected by a state change."""
qry_dag = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id.in_(run_ids),
@@ -201,7 +207,7 @@ def _iter_subdag_run_ids(
dag: DAG,
session: SASession,
state: DagRunState,
- task_ids: List[str],
+ task_ids: list[str],
commit: bool,
confirmed_infos: Iterable[_DagRunInfo],
) -> Iterator[str]:
@@ -244,7 +250,7 @@ def verify_dagruns(
session: SASession,
current_task: Operator,
):
- """Verifies integrity of dag_runs.
+ """Verify integrity of dag_runs.
:param dag_runs: dag runs to verify
:param commit: whether dag runs state should be updated
@@ -261,7 +267,7 @@ def verify_dagruns(
session.merge(dag_run)
-def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str], session: SASession) -> Iterator[_DagRunInfo]:
+def _iter_existing_dag_run_infos(dag: DAG, run_ids: list[str], session: SASession) -> Iterator[_DagRunInfo]:
for dag_run in DagRun.find(dag_id=dag.dag_id, run_id=run_ids, session=session):
dag_run.dag = dag
dag_run.verify_integrity(session=session)
@@ -288,8 +294,8 @@ def find_task_relatives(tasks, downstream, upstream):
@provide_session
def get_execution_dates(
dag: DAG, execution_date: datetime, future: bool, past: bool, *, session: SASession = NEW_SESSION
-) -> List[datetime]:
- """Returns dates of DAG execution"""
+) -> list[datetime]:
+ """Return DAG execution dates."""
latest_execution_date = dag.get_latest_execution_date(session=session)
if latest_execution_date is None:
raise ValueError(f"Received non-localized date {execution_date}")
@@ -317,7 +323,7 @@ def get_execution_dates(
@provide_session
def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
- """Returns run_ids of DAG execution"""
+ """Return DAG executions' run_ids."""
last_dagrun = dag.get_last_dagrun(include_externally_triggered=True, session=session)
current_dagrun = dag.get_dagrun(run_id=run_id, session=session)
first_dagrun = (
@@ -328,7 +334,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess
)
if last_dagrun is None:
- raise ValueError(f'DagRun for {dag.dag_id} not found')
+ raise ValueError(f"DagRun for {dag.dag_id} not found")
# determine run_id range of dag runs and tasks to consider
end_date = last_dagrun.logical_date if future else current_dagrun.logical_date
@@ -350,7 +356,7 @@ def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASess
def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession = NEW_SESSION):
"""
- Helper method that set dag run state in the DB.
+ Set dag run state in the DB.
:param dag_id: dag_id of target dag run
:param run_id: run id of target dag run
@@ -371,14 +377,15 @@ def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SA
def set_dag_run_state_to_success(
*,
dag: DAG,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
-) -> List[TaskInstance]:
+) -> list[TaskInstance]:
"""
- Set the dag run for a specific execution date and its task instances
- to success.
+ Set the dag run's state to success.
+
+ Set the dag run for a specific execution date, and its task instances, to success.
:param dag: the DAG of which to alter state
:param execution_date: the execution date from which to start looking (deprecated)
@@ -400,10 +407,10 @@ def set_dag_run_state_to_success(
raise ValueError(f"Received non-localized date {execution_date}")
dag_run = dag.get_dagrun(execution_date=execution_date)
if not dag_run:
- raise ValueError(f'DagRun with execution_date: {execution_date} not found')
+ raise ValueError(f"DagRun with execution_date: {execution_date} not found")
run_id = dag_run.run_id
if not run_id:
- raise ValueError(f'Invalid dag_run_id: {run_id}')
+ raise ValueError(f"Invalid dag_run_id: {run_id}")
# Mark the dag run to success.
if commit:
_set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)
@@ -418,14 +425,15 @@ def set_dag_run_state_to_success(
def set_dag_run_state_to_failed(
*,
dag: DAG,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
-) -> List[TaskInstance]:
+) -> list[TaskInstance]:
"""
- Set the dag run for a specific execution date or run_id and its running task instances
- to failed.
+ Set the dag run's state to failed.
+
+ Set the dag run for a specific execution date or run_id, and its running task instances, to failed.
:param dag: the DAG of which to alter state
:param execution_date: the execution date from which to start looking (deprecated)
@@ -446,11 +454,11 @@ def set_dag_run_state_to_failed(
raise ValueError(f"Received non-localized date {execution_date}")
dag_run = dag.get_dagrun(execution_date=execution_date)
if not dag_run:
- raise ValueError(f'DagRun with execution_date: {execution_date} not found')
+ raise ValueError(f"DagRun with execution_date: {execution_date} not found")
run_id = dag_run.run_id
if not run_id:
- raise ValueError(f'Invalid dag_run_id: {run_id}')
+ raise ValueError(f"Invalid dag_run_id: {run_id}")
# Mark the dag run to failed.
if commit:
@@ -462,7 +470,7 @@ def set_dag_run_state_to_failed(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.task_id.in_(task_ids),
- TaskInstance.state.in_(State.running),
+ TaskInstance.state.in_([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
)
task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]
@@ -478,7 +486,7 @@ def set_dag_run_state_to_failed(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.run_id == run_id,
TaskInstance.state.not_in(State.finished),
- TaskInstance.state.not_in(State.running),
+ TaskInstance.state.not_in([State.RUNNING, State.DEFERRED, State.UP_FOR_RESCHEDULE]),
)
tis = [ti for ti in tis]
@@ -493,11 +501,11 @@ def __set_dag_run_state_to_running_or_queued(
*,
new_state: DagRunState,
dag: DAG,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
-) -> List[TaskInstance]:
+) -> list[TaskInstance]:
"""
Set the dag run for a specific execution date to running.
@@ -509,7 +517,7 @@ def __set_dag_run_state_to_running_or_queued(
:return: If commit is true, list of tasks that have been updated,
otherwise list of tasks that will be updated
"""
- res: List[TaskInstance] = []
+ res: list[TaskInstance] = []
if not (execution_date is None) ^ (run_id is None):
return res
@@ -523,10 +531,10 @@ def __set_dag_run_state_to_running_or_queued(
raise ValueError(f"Received non-localized date {execution_date}")
dag_run = dag.get_dagrun(execution_date=execution_date)
if not dag_run:
- raise ValueError(f'DagRun with execution_date: {execution_date} not found')
+ raise ValueError(f"DagRun with execution_date: {execution_date} not found")
run_id = dag_run.run_id
if not run_id:
- raise ValueError(f'DagRun with run_id: {run_id} not found')
+ raise ValueError(f"DagRun with run_id: {run_id} not found")
# Mark the dag run to running.
if commit:
_set_dag_run_state(dag.dag_id, run_id, new_state, session)
@@ -539,11 +547,16 @@ def __set_dag_run_state_to_running_or_queued(
def set_dag_run_state_to_running(
*,
dag: DAG,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
-) -> List[TaskInstance]:
+) -> list[TaskInstance]:
+ """
+ Set the dag run's state to running.
+
+ Set the dag run for a specific execution date, and its task instances, to running.
+ """
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.RUNNING,
dag=dag,
@@ -558,11 +571,16 @@ def set_dag_run_state_to_running(
def set_dag_run_state_to_queued(
*,
dag: DAG,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
commit: bool = False,
session: SASession = NEW_SESSION,
-) -> List[TaskInstance]:
+) -> list[TaskInstance]:
+ """
+ Set the dag run's state to queued.
+
+ Set the dag run for a specific execution date, and its task instances, to queued.
+ """
return __set_dag_run_state_to_running_or_queued(
new_state=DagRunState.QUEUED,
dag=dag,
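Note on the guard `if not (execution_date is None) ^ (run_id is None)` used in `__set_dag_run_state_to_running_or_queued`: `^` binds more tightly than `not`, so the condition is true unless exactly one of the two arguments is supplied. A minimal sketch of that check in isolation; the function name `exactly_one_given` is hypothetical:

```python
from datetime import datetime
from typing import Optional


def exactly_one_given(execution_date: Optional[datetime], run_id: Optional[str]) -> bool:
    # XOR of two booleans: True only when one argument is None and the other
    # is not. The guard in the diff negates this result and returns early.
    return (execution_date is None) ^ (run_id is None)


when = datetime(2022, 9, 1)
print(exactly_one_given(None, None))        # neither given
print(exactly_one_given(when, None))        # only execution_date
print(exactly_one_given(None, "manual_1"))  # only run_id
print(exactly_one_given(when, "manual_1"))  # both given
```

In the diff, failing this check returns an empty list rather than raising, so callers that pass both or neither argument get no error.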
diff --git a/airflow/api/common/trigger_dag.py b/airflow/api/common/trigger_dag.py
index 225efad683717..01da7745c7bf3 100644
--- a/airflow/api/common/trigger_dag.py
+++ b/airflow/api/common/trigger_dag.py
@@ -16,11 +16,10 @@
# specific language governing permissions and limitations
# under the License.
"""Triggering DAG runs APIs."""
+from __future__ import annotations
+
import json
from datetime import datetime
-from typing import List, Optional, Union
-
-import pendulum
from airflow.exceptions import DagNotFound, DagRunAlreadyExists
from airflow.models import DagBag, DagModel, DagRun
@@ -32,11 +31,11 @@
def _trigger_dag(
dag_id: str,
dag_bag: DagBag,
- run_id: Optional[str] = None,
- conf: Optional[Union[dict, str]] = None,
- execution_date: Optional[datetime] = None,
+ run_id: str | None = None,
+ conf: dict | str | None = None,
+ execution_date: datetime | None = None,
replace_microseconds: bool = True,
-) -> List[Optional[DagRun]]:
+) -> list[DagRun | None]:
"""Triggers DAG run.
:param dag_id: DAG ID
@@ -60,21 +59,23 @@ def _trigger_dag(
if replace_microseconds:
execution_date = execution_date.replace(microsecond=0)
- if dag.default_args and 'start_date' in dag.default_args:
+ if dag.default_args and "start_date" in dag.default_args:
min_dag_start_date = dag.default_args["start_date"]
if min_dag_start_date and execution_date < min_dag_start_date:
raise ValueError(
f"The execution_date [{execution_date.isoformat()}] should be >= start_date "
f"[{min_dag_start_date.isoformat()}] from DAG's default_args"
)
+ logical_date = timezone.coerce_datetime(execution_date)
- run_id = run_id or DagRun.generate_run_id(DagRunType.MANUAL, execution_date)
+ data_interval = dag.timetable.infer_manual_data_interval(run_after=logical_date)
+ run_id = run_id or dag.timetable.generate_run_id(
+ run_type=DagRunType.MANUAL, logical_date=logical_date, data_interval=data_interval
+ )
dag_run = DagRun.find_duplicate(dag_id=dag_id, execution_date=execution_date, run_id=run_id)
if dag_run:
- raise DagRunAlreadyExists(
- f"A Dag Run already exists for dag id {dag_id} at {execution_date} with run id {run_id}"
- )
+ raise DagRunAlreadyExists(dag_run=dag_run, execution_date=execution_date, run_id=run_id)
run_conf = None
if conf:
@@ -90,9 +91,7 @@ def _trigger_dag(
conf=run_conf,
external_trigger=True,
dag_hash=dag_bag.dags_hash.get(dag_id),
- data_interval=_dag.timetable.infer_manual_data_interval(
- run_after=pendulum.instance(execution_date)
- ),
+ data_interval=data_interval,
)
dag_runs.append(dag_run)
@@ -101,12 +100,12 @@ def _trigger_dag(
def trigger_dag(
dag_id: str,
- run_id: Optional[str] = None,
- conf: Optional[Union[dict, str]] = None,
- execution_date: Optional[datetime] = None,
+ run_id: str | None = None,
+ conf: dict | str | None = None,
+ execution_date: datetime | None = None,
replace_microseconds: bool = True,
-) -> Optional[DagRun]:
- """Triggers execution of DAG specified by dag_id
+) -> DagRun | None:
+ """Triggers execution of DAG specified by dag_id.
:param dag_id: DAG ID
:param run_id: ID of the dag_run
diff --git a/airflow/api_connexion/endpoints/config_endpoint.py b/airflow/api_connexion/endpoints/config_endpoint.py
index bdd2b3a959547..a9353234f0b4e 100644
--- a/airflow/api_connexion/endpoints/config_endpoint.py
+++ b/airflow/api_connexion/endpoints/config_endpoint.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from http import HTTPStatus
@@ -26,11 +27,11 @@
from airflow.security import permissions
from airflow.settings import json
-LINE_SEP = '\n' # `\n` cannot appear in f-strings
+LINE_SEP = "\n" # `\n` cannot appear in f-strings
def _conf_dict_to_config(conf_dict: dict) -> Config:
- """Convert config dict to a Config object"""
+ """Convert config dict to a Config object."""
config = Config(
sections=[
ConfigSection(
@@ -43,25 +44,25 @@ def _conf_dict_to_config(conf_dict: dict) -> Config:
def _option_to_text(config_option: ConfigOption) -> str:
- """Convert a single config option to text"""
- return f'{config_option.key} = {config_option.value}'
+ """Convert a single config option to text."""
+ return f"{config_option.key} = {config_option.value}"
def _section_to_text(config_section: ConfigSection) -> str:
- """Convert a single config section to text"""
+ """Convert a single config section to text."""
return (
- f'[{config_section.name}]{LINE_SEP}'
- f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}'
+ f"[{config_section.name}]{LINE_SEP}"
+ f"{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}"
)
def _config_to_text(config: Config) -> str:
- """Convert the entire config to text"""
+ """Convert the entire config to text."""
return LINE_SEP.join(_section_to_text(s) for s in config.sections)
def _config_to_json(config: Config) -> str:
- """Convert a Config object to a JSON formatted string"""
+ """Convert a Config object to a JSON formatted string."""
return json.dumps(config_schema.dump(config), indent=4)
@@ -69,8 +70,8 @@ def _config_to_json(config: Config) -> str:
def get_config() -> Response:
"""Get current configuration."""
serializer = {
- 'text/plain': _config_to_text,
- 'application/json': _config_to_json,
+ "text/plain": _config_to_text,
+ "application/json": _config_to_json,
}
return_type = request.accept_mimetypes.best_match(serializer.keys())
if return_type not in serializer:
@@ -79,11 +80,11 @@ def get_config() -> Response:
conf_dict = conf.as_dict(display_source=False, display_sensitive=True)
config = _conf_dict_to_config(conf_dict)
config_text = serializer[return_type](config)
- return Response(config_text, headers={'Content-Type': return_type})
+ return Response(config_text, headers={"Content-Type": return_type})
else:
raise PermissionDenied(
detail=(
- 'Your Airflow administrator chose not to expose the configuration, most likely for security'
- ' reasons.'
+ "Your Airflow administrator chose not to expose the configuration, most likely for security"
+ " reasons."
)
)
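Note on the `text/plain` serializer above: each section is rendered as a `[name]` header followed by one `key = value` line per option, with `LINE_SEP` kept outside the f-string expressions. A minimal sketch of that rendering; the `Option` and `Section` dataclasses are hypothetical stand-ins for `ConfigOption` and `ConfigSection`, and the sample values are made up:

```python
from dataclasses import dataclass
from typing import List

LINE_SEP = "\n"  # defined outside the f-strings, mirroring the diff


@dataclass
class Option:  # hypothetical stand-in for ConfigOption
    key: str
    value: str


@dataclass
class Section:  # hypothetical stand-in for ConfigSection
    name: str
    options: List[Option]


def option_to_text(option: Option) -> str:
    return f"{option.key} = {option.value}"


def section_to_text(section: Section) -> str:
    # Header line, then one line per option, then a trailing separator.
    return (
        f"[{section.name}]{LINE_SEP}"
        f"{LINE_SEP.join(option_to_text(o) for o in section.options)}{LINE_SEP}"
    )


section = Section("core", [Option("executor", "LocalExecutor"), Option("parallelism", "32")])
print(section_to_text(section), end="")
```

Joining several such sections with `LINE_SEP`, as `_config_to_text` does, leaves one blank line between consecutive sections because each section already ends with a separator.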
diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py
index b196b3236b911..40cb474bda4ea 100644
--- a/airflow/api_connexion/endpoints/connection_endpoint.py
+++ b/airflow/api_connexion/endpoints/connection_endpoint.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import os
from http import HTTPStatus
@@ -37,18 +38,28 @@
from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.security import permissions
+from airflow.utils.log.action_logger import action_event_from_permission
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.strings import get_random_string
+from airflow.www.decorators import action_logging
+
+RESOURCE_EVENT_PREFIX = "connection"
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)])
@provide_session
+@action_logging(
+ event=action_event_from_permission(
+ prefix=RESOURCE_EVENT_PREFIX,
+ permission=permissions.ACTION_CAN_DELETE,
+ ),
+)
def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
- """Delete a connection entry"""
+ """Delete a connection entry."""
connection = session.query(Connection).filter_by(conn_id=connection_id).one_or_none()
if connection is None:
raise NotFound(
- 'Connection not found',
+ "Connection not found",
detail=f"The Connection with connection_id: `{connection_id}` was not found",
)
session.delete(connection)
@@ -58,7 +69,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) ->
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
@provide_session
def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
- """Get a connection entry"""
+ """Get a connection entry."""
connection = session.query(Connection).filter(Connection.conn_id == connection_id).one_or_none()
if connection is None:
raise NotFound(
@@ -69,7 +80,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
-@format_parameters({'limit': check_limit})
+@format_parameters({"limit": check_limit})
@provide_session
def get_connections(
*,
@@ -78,9 +89,9 @@ def get_connections(
order_by: str = "id",
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get all connection entries"""
+ """Get all connection entries."""
to_replace = {"connection_id": "conn_id"}
- allowed_filter_attrs = ['connection_id', 'conn_type', 'description', 'host', 'port', 'id']
+ allowed_filter_attrs = ["connection_id", "conn_type", "description", "host", "port", "id"]
total_entries = session.query(func.count(Connection.id)).scalar()
query = session.query(Connection)
@@ -93,26 +104,32 @@ def get_connections(
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)])
@provide_session
+@action_logging(
+ event=action_event_from_permission(
+ prefix=RESOURCE_EVENT_PREFIX,
+ permission=permissions.ACTION_CAN_EDIT,
+ ),
+)
def patch_connection(
*,
connection_id: str,
update_mask: UpdateMask = None,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Update a connection entry"""
+ """Update a connection entry."""
try:
data = connection_schema.load(request.json, partial=True)
except ValidationError as err:
# If validation gets here, it is extra field validation.
raise BadRequest(detail=str(err.messages))
- non_update_fields = ['connection_id', 'conn_id']
+ non_update_fields = ["connection_id", "conn_id"]
connection = session.query(Connection).filter_by(conn_id=connection_id).first()
if connection is None:
raise NotFound(
"Connection not found",
detail=f"The Connection with connection_id: `{connection_id}` was not found",
)
- if data.get('conn_id') and connection.conn_id != data['conn_id']:
+ if data.get("conn_id") and connection.conn_id != data["conn_id"]:
raise BadRequest(detail="The connection_id cannot be updated.")
if update_mask:
update_mask = [i.strip() for i in update_mask]
@@ -132,14 +149,20 @@ def patch_connection(
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
@provide_session
+@action_logging(
+ event=action_event_from_permission(
+ prefix=RESOURCE_EVENT_PREFIX,
+ permission=permissions.ACTION_CAN_CREATE,
+ ),
+)
def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
- """Create connection entry"""
+ """Create connection entry."""
body = request.json
try:
data = connection_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
- conn_id = data['conn_id']
+ conn_id = data["conn_id"]
query = session.query(Connection)
connection = query.filter_by(conn_id=conn_id).first()
if not connection:
@@ -153,16 +176,18 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
def test_connection() -> APIResponse:
"""
- To test a connection, this method first creates an in-memory dummy conn_id & exports that to an
+ Test an API connection.
+
+ This method first creates an in-memory dummy conn_id & exports that to an
env var, as some hook classes try to find out the conn from their __init__ method & error out
if not found. It also deletes the conn id env variable after the test.
"""
body = request.json
dummy_conn_id = get_random_string()
- conn_env_var = f'{CONN_ENV_PREFIX}{dummy_conn_id.upper()}'
+ conn_env_var = f"{CONN_ENV_PREFIX}{dummy_conn_id.upper()}"
try:
data = connection_schema.load(body)
- data['conn_id'] = dummy_conn_id
+ data["conn_id"] = dummy_conn_id
conn = Connection(**data)
os.environ[conn_env_var] = conn.get_uri()
status, message = conn.test_connection()
diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index 0505f864ee333..be66dc814e95b 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from http import HTTPStatus
-from typing import Collection, Optional
+from typing import Collection
from connexion import NoContent
-from flask import current_app, g, request
+from flask import g, request
from marshmallow import ValidationError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_
@@ -38,6 +39,7 @@
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models.dag import DagModel, DagTag
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
@@ -56,21 +58,21 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
def get_dag_details(*, dag_id: str) -> APIResponse:
"""Get details of DAG."""
- dag: DAG = current_app.dag_bag.get_dag(dag_id)
+ dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")
return dag_detail_schema.dump(dag)
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
-@format_parameters({'limit': check_limit})
+@format_parameters({"limit": check_limit})
@provide_session
def get_dags(
*,
limit: int,
offset: int = 0,
- tags: Optional[Collection[str]] = None,
- dag_id_pattern: Optional[str] = None,
+ tags: Collection[str] | None = None,
+ dag_id_pattern: str | None = None,
only_active: bool = True,
session: Session = NEW_SESSION,
) -> APIResponse:
@@ -81,9 +83,9 @@ def get_dags(
dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
if dag_id_pattern:
- dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
+ dags_query = dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
- readable_dags = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
+ readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))
if tags:
@@ -100,27 +102,27 @@ def get_dags(
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
@provide_session
def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse:
- """Update the specific DAG"""
+ """Update the specific DAG."""
try:
patch_body = dag_schema.load(request.json, session=session)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
if update_mask:
patch_body_ = {}
- if update_mask != ['is_paused']:
+ if update_mask != ["is_paused"]:
raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
patch_body_[update_mask[0]] = patch_body[update_mask[0]]
patch_body = patch_body_
dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one_or_none()
if not dag:
raise NotFound(f"Dag with id: '{dag_id}' not found")
- dag.is_paused = patch_body['is_paused']
+ dag.is_paused = patch_body["is_paused"]
session.flush()
return dag_schema.dump(dag)
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
-@format_parameters({'limit': check_limit})
+@format_parameters({"limit": check_limit})
@provide_session
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
"""Patch multiple DAGs."""
@@ -130,7 +132,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
raise BadRequest(detail=str(err.messages))
if update_mask:
patch_body_ = {}
- if update_mask != ['is_paused']:
+ if update_mask != ["is_paused"]:
raise BadRequest(detail="Only `is_paused` field can be updated through the REST API")
update_mask = update_mask[0]
patch_body_[update_mask] = patch_body[update_mask]
@@ -140,10 +142,10 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
else:
dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
- if dag_id_pattern == '~':
- dag_id_pattern = '%'
- dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
- editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
+ if dag_id_pattern == "~":
+ dag_id_pattern = "%"
+ dags_query = dags_query.filter(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
+ editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)
dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
if tags:
@@ -156,7 +158,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
dags_to_update = {dag.dag_id for dag in dags}
session.query(DagModel).filter(DagModel.dag_id.in_(dags_to_update)).update(
- {DagModel.is_paused: patch_body['is_paused']}, synchronize_session='fetch'
+ {DagModel.is_paused: patch_body["is_paused"]}, synchronize_session="fetch"
)
session.flush()
diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index e510126534b12..b566ba2df72f6 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from http import HTTPStatus
-from typing import List, Optional, Tuple
import pendulum
from connexion import NoContent
-from flask import current_app, g, request
+from flask import g
from marshmallow import ValidationError
from sqlalchemy import or_
from sqlalchemy.orm import Query, Session
@@ -30,6 +31,7 @@
set_dag_run_state_to_success,
)
from airflow.api_connexion import security
+from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
@@ -38,8 +40,13 @@
dagrun_collection_schema,
dagrun_schema,
dagruns_batch_form_schema,
+ set_dagrun_note_form_schema,
set_dagrun_state_form_schema,
)
+from airflow.api_connexion.schemas.dataset_schema import (
+ DatasetEventCollection,
+ dataset_event_collection_schema,
+)
from airflow.api_connexion.schemas.task_instance_schema import (
TaskInstanceReferenceCollection,
task_instance_reference_collection_schema,
@@ -47,6 +54,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models import DagModel, DagRun
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
@@ -60,7 +68,7 @@
)
@provide_session
def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
- """Delete a DAG Run"""
+ """Delete a DAG Run."""
if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0:
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, HTTPStatus.NO_CONTENT
@@ -84,19 +92,50 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION)
return dagrun_schema.dump(dag_run)
+@security.requires_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET),
+ ],
+)
+@provide_session
+def get_upstream_dataset_events(
+ *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION
+) -> APIResponse:
+ """If dag run is dataset-triggered, return the dataset events that triggered it."""
+ dag_run: DagRun | None = (
+ session.query(DagRun)
+ .filter(
+ DagRun.dag_id == dag_id,
+ DagRun.run_id == dag_run_id,
+ )
+ .one_or_none()
+ )
+ if dag_run is None:
+ raise NotFound(
+ "DAGRun not found",
+ detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found",
+ )
+ events = dag_run.consumed_dataset_events
+ return dataset_event_collection_schema.dump(
+ DatasetEventCollection(dataset_events=events, total_entries=len(events))
+ )
+
+
def _fetch_dag_runs(
query: Query,
*,
- end_date_gte: Optional[str],
- end_date_lte: Optional[str],
- execution_date_gte: Optional[str],
- execution_date_lte: Optional[str],
- start_date_gte: Optional[str],
- start_date_lte: Optional[str],
- limit: Optional[int],
- offset: Optional[int],
+ end_date_gte: str | None,
+ end_date_lte: str | None,
+ execution_date_gte: str | None,
+ execution_date_lte: str | None,
+ start_date_gte: str | None,
+ start_date_lte: str | None,
+ limit: int | None,
+ offset: int | None,
order_by: str,
-) -> Tuple[List[DagRun], int]:
+) -> tuple[list[DagRun], int]:
if start_date_gte:
query = query.filter(DagRun.start_date >= start_date_gte)
if start_date_lte:
@@ -137,28 +176,28 @@ def _fetch_dag_runs(
)
@format_parameters(
{
- 'start_date_gte': format_datetime,
- 'start_date_lte': format_datetime,
- 'execution_date_gte': format_datetime,
- 'execution_date_lte': format_datetime,
- 'end_date_gte': format_datetime,
- 'end_date_lte': format_datetime,
- 'limit': check_limit,
+ "start_date_gte": format_datetime,
+ "start_date_lte": format_datetime,
+ "execution_date_gte": format_datetime,
+ "execution_date_lte": format_datetime,
+ "end_date_gte": format_datetime,
+ "end_date_lte": format_datetime,
+ "limit": check_limit,
}
)
@provide_session
def get_dag_runs(
*,
dag_id: str,
- start_date_gte: Optional[str] = None,
- start_date_lte: Optional[str] = None,
- execution_date_gte: Optional[str] = None,
- execution_date_lte: Optional[str] = None,
- end_date_gte: Optional[str] = None,
- end_date_lte: Optional[str] = None,
- state: Optional[List[str]] = None,
- offset: Optional[int] = None,
- limit: Optional[int] = None,
+ start_date_gte: str | None = None,
+ start_date_lte: str | None = None,
+ execution_date_gte: str | None = None,
+ execution_date_lte: str | None = None,
+ end_date_gte: str | None = None,
+ end_date_lte: str | None = None,
+ state: list[str] | None = None,
+ offset: int | None = None,
+ limit: int | None = None,
order_by: str = "id",
session: Session = NEW_SESSION,
):
@@ -167,7 +206,7 @@ def get_dag_runs(
# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id == "~":
- appbuilder = current_app.appbuilder
+ appbuilder = get_airflow_app().appbuilder
query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
else:
query = query.filter(DagRun.dag_id == dag_id)
@@ -198,14 +237,14 @@ def get_dag_runs(
)
@provide_session
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
- """Get list of DAG Runs"""
- body = request.get_json()
+ """Get list of DAG Runs."""
+ body = get_json_request_dict()
try:
data = dagruns_batch_form_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
- appbuilder = current_app.appbuilder
+ appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
query = session.query(DagRun)
if data.get("dag_ids"):
@@ -252,7 +291,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
detail=f"DAG with dag_id: '{dag_id}' has import errors",
)
try:
- post_body = dagrun_schema.load(request.json, session=session)
+ post_body = dagrun_schema.load(get_json_request_dict(), session=session)
except ValidationError as err:
raise BadRequest(detail=str(err))
@@ -268,7 +307,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
)
if not dagrun_instance:
try:
- dag = current_app.dag_bag.get_dag(dag_id)
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
@@ -277,7 +316,8 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
state=DagRunState.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
- dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
+ dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
+ session=session,
)
return dagrun_schema.dump(dag_run)
except ValueError as ve:
@@ -303,19 +343,19 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set a state of a dag run."""
- dag_run: Optional[DagRun] = (
+ dag_run: DagRun | None = (
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
)
if dag_run is None:
- error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
+ error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
raise NotFound(error_message)
try:
- post_body = set_dagrun_state_form_schema.load(request.json)
+ post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
except ValidationError as err:
raise BadRequest(detail=str(err))
- state = post_body['state']
- dag = current_app.dag_bag.get_dag(dag_id)
+ state = post_body["state"]
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
if state == DagRunState.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DagRunState.QUEUED:
@@ -335,19 +375,19 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
@provide_session
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Clear a dag run."""
- dag_run: Optional[DagRun] = (
+ dag_run: DagRun | None = (
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
)
if dag_run is None:
- error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
+ error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
raise NotFound(error_message)
try:
- post_body = clear_dagrun_form_schema.load(request.json)
+ post_body = clear_dagrun_form_schema.load(get_json_request_dict())
except ValidationError as err:
raise BadRequest(detail=str(err))
- dry_run = post_body.get('dry_run', False)
- dag = current_app.dag_bag.get_dag(dag_id)
+ dry_run = post_body.get("dry_run", False)
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
start_date = dag_run.logical_date
end_date = dag_run.logical_date
@@ -373,5 +413,38 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
include_parentdag=True,
only_failed=False,
)
- dag_run.refresh_from_db()
+ dag_run = session.query(DagRun).filter(DagRun.id == dag_run.id).one()
return dagrun_schema.dump(dag_run)
+
+
+@security.requires_access(
+ [
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
+ ],
+)
+@provide_session
+def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
+ """Set the note for a dag run."""
+ dag_run: DagRun | None = (
+ session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
+ )
+ if dag_run is None:
+ error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
+ raise NotFound(error_message)
+ try:
+ post_body = set_dagrun_note_form_schema.load(get_json_request_dict())
+ new_note = post_body["note"]
+ except ValidationError as err:
+ raise BadRequest(detail=str(err))
+
+ from flask_login import current_user
+
+ current_user_id = getattr(current_user, "id", None)
+ if dag_run.dag_run_note is None:
+ dag_run.note = (new_note, current_user_id)
+ else:
+ dag_run.dag_run_note.content = new_note
+ dag_run.dag_run_note.user_id = current_user_id
+ session.commit()
+ return dagrun_schema.dump(dag_run)
diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py
index ad6209221e523..42ccd4e5d8671 100644
--- a/airflow/api_connexion/endpoints/dag_source_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from http import HTTPStatus
@@ -29,7 +30,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)])
def get_dag_source(*, file_token: str) -> Response:
- """Get source code using file token"""
+ """Get source code using file token."""
secret_key = current_app.config["SECRET_KEY"]
auth_s = URLSafeSerializer(secret_key)
try:
@@ -38,10 +39,10 @@ def get_dag_source(*, file_token: str) -> Response:
except (BadSignature, FileNotFoundError):
raise NotFound("Dag source not found")
- return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json'])
- if return_type == 'text/plain':
- return Response(dag_source, headers={'Content-Type': return_type})
- if return_type == 'application/json':
+ return_type = request.accept_mimetypes.best_match(["text/plain", "application/json"])
+ if return_type == "text/plain":
+ return Response(dag_source, headers={"Content-Type": return_type})
+ if return_type == "application/json":
content = dag_source_schema.dumps(dict(content=dag_source))
- return Response(content, headers={'Content-Type': return_type})
+ return Response(content, headers={"Content-Type": return_type})
return Response("Not Allowed Accept Header", status=HTTPStatus.NOT_ACCEPTABLE)
diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py
new file mode 100644
index 0000000000000..5a73afd1a33a8
--- /dev/null
+++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py
@@ -0,0 +1,61 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from sqlalchemy.orm import Session
+
+from airflow.api_connexion import security
+from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
+from airflow.api_connexion.schemas.dag_warning_schema import (
+ DagWarningCollection,
+ dag_warning_collection_schema,
+)
+from airflow.api_connexion.types import APIResponse
+from airflow.models.dagwarning import DagWarning as DagWarningModel
+from airflow.security import permissions
+from airflow.utils.session import NEW_SESSION, provide_session
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING)])
+@format_parameters({"limit": check_limit})
+@provide_session
+def get_dag_warnings(
+ *,
+ limit: int,
+ dag_id: str | None = None,
+ warning_type: str | None = None,
+ offset: int | None = None,
+ order_by: str = "timestamp",
+ session: Session = NEW_SESSION,
+) -> APIResponse:
+ """Get DAG warnings.
+
+ :param dag_id: the dag_id to optionally filter by
+ :param warning_type: the warning type to optionally filter by
+ """
+ allowed_filter_attrs = ["dag_id", "warning_type", "message", "timestamp"]
+ query = session.query(DagWarningModel)
+ if dag_id:
+ query = query.filter(DagWarningModel.dag_id == dag_id)
+ if warning_type:
+ query = query.filter(DagWarningModel.warning_type == warning_type)
+ total_entries = query.count()
+ query = apply_sorting(query=query, order_by=order_by, allowed_attrs=allowed_filter_attrs)
+ dag_warnings = query.offset(offset).limit(limit).all()
+ return dag_warning_collection_schema.dump(
+ DagWarningCollection(dag_warnings=dag_warnings, total_entries=total_entries)
+ )
diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py
new file mode 100644
index 0000000000000..42e8bb3c36c33
--- /dev/null
+++ b/airflow/api_connexion/endpoints/dataset_endpoint.py
@@ -0,0 +1,122 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session, joinedload, subqueryload
+
+from airflow.api_connexion import security
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
+from airflow.api_connexion.schemas.dataset_schema import (
+ DatasetCollection,
+ DatasetEventCollection,
+ dataset_collection_schema,
+ dataset_event_collection_schema,
+ dataset_schema,
+)
+from airflow.api_connexion.types import APIResponse
+from airflow.models.dataset import DatasetEvent, DatasetModel
+from airflow.security import permissions
+from airflow.utils.session import NEW_SESSION, provide_session
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@provide_session
+def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
+ """Get a Dataset."""
+ dataset = (
+ session.query(DatasetModel)
+ .filter(DatasetModel.uri == uri)
+ .options(joinedload(DatasetModel.consuming_dags), joinedload(DatasetModel.producing_tasks))
+ .one_or_none()
+ )
+ if not dataset:
+ raise NotFound(
+ "Dataset not found",
+ detail=f"The Dataset with uri: `{uri}` was not found",
+ )
+ return dataset_schema.dump(dataset)
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@format_parameters({"limit": check_limit})
+@provide_session
+def get_datasets(
+ *,
+ limit: int,
+ offset: int = 0,
+ uri_pattern: str | None = None,
+ order_by: str = "id",
+ session: Session = NEW_SESSION,
+) -> APIResponse:
+ """Get datasets."""
+ allowed_attrs = ["id", "uri", "created_at", "updated_at"]
+
+ total_entries = session.query(func.count(DatasetModel.id)).scalar()
+ query = session.query(DatasetModel)
+ if uri_pattern:
+ query = query.filter(DatasetModel.uri.ilike(f"%{uri_pattern}%"))
+ query = apply_sorting(query, order_by, {}, allowed_attrs)
+ datasets = (
+ query.options(subqueryload(DatasetModel.consuming_dags), subqueryload(DatasetModel.producing_tasks))
+ .offset(offset)
+ .limit(limit)
+ .all()
+ )
+ return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))
+
+
+@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
+@provide_session
+@format_parameters({"limit": check_limit})
+def get_dataset_events(
+ *,
+ limit: int,
+ offset: int = 0,
+ order_by: str = "timestamp",
+ dataset_id: int | None = None,
+ source_dag_id: str | None = None,
+ source_task_id: str | None = None,
+ source_run_id: str | None = None,
+ source_map_index: int | None = None,
+ session: Session = NEW_SESSION,
+) -> APIResponse:
+ """Get dataset events."""
+ allowed_attrs = ["source_dag_id", "source_task_id", "source_run_id", "source_map_index", "timestamp"]
+
+ query = session.query(DatasetEvent)
+
+ if dataset_id:
+ query = query.filter(DatasetEvent.dataset_id == dataset_id)
+ if source_dag_id:
+ query = query.filter(DatasetEvent.source_dag_id == source_dag_id)
+ if source_task_id:
+ query = query.filter(DatasetEvent.source_task_id == source_task_id)
+ if source_run_id:
+ query = query.filter(DatasetEvent.source_run_id == source_run_id)
+ if source_map_index:
+ query = query.filter(DatasetEvent.source_map_index == source_map_index)
+
+ query = query.options(subqueryload(DatasetEvent.created_dagruns))
+
+ total_entries = query.count()
+ query = apply_sorting(query, order_by, {}, allowed_attrs)
+ events = query.offset(offset).limit(limit).all()
+ return dataset_event_collection_schema.dump(
+ DatasetEventCollection(dataset_events=events, total_entries=total_entries)
+ )
diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py
index 590190ab98a20..94fa73a431597 100644
--- a/airflow/api_connexion/endpoints/event_log_endpoint.py
+++ b/airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-from typing import Optional
+from __future__ import annotations
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -37,7 +36,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)])
@provide_session
def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIResponse:
- """Get a log entry"""
+ """Get a log entry."""
event_log = session.query(Log).get(event_log_id)
if event_log is None:
raise NotFound("Event Log not found")
@@ -50,14 +49,14 @@ def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIRe
def get_event_logs(
*,
limit: int,
- offset: Optional[int] = None,
+ offset: int | None = None,
order_by: str = "event_log_id",
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get all log entries from event log"""
+ """Get all log entries from event log."""
to_replace = {"event_log_id": "id", "when": "dttm"}
allowed_filter_attrs = [
- 'event_log_id',
+ "event_log_id",
"when",
"dag_id",
"task_id",
diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py
index 3e9535603bda3..2b12667e7cdfa 100644
--- a/airflow/api_connexion/endpoints/extra_link_endpoint.py
+++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py
@@ -14,8 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from flask import current_app
from sqlalchemy.orm.session import Session
from airflow import DAG
@@ -25,6 +25,7 @@
from airflow.exceptions import TaskNotFound
from airflow.models.dagbag import DagBag
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
@@ -43,10 +44,10 @@ def get_extra_links(
task_id: str,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get extra links for task instance"""
+ """Get extra links for task instance."""
from airflow.models.taskinstance import TaskInstance
- dagbag: DagBag = current_app.dag_bag
+ dagbag: DagBag = get_airflow_app().dag_bag
dag: DAG = dagbag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found')
@@ -73,6 +74,6 @@ def get_extra_links(
(link_name, task.get_extra_links(ti, link_name)) for link_name in task.extra_links
)
all_extra_links = {
- link_name: link_url if link_url else None for link_name, link_url in all_extra_link_pairs
+ link_name: link_url if link_url else None for link_name, link_url in sorted(all_extra_link_pairs)
}
return all_extra_links
diff --git a/airflow/api_connexion/endpoints/health_endpoint.py b/airflow/api_connexion/endpoints/health_endpoint.py
index 380225bf16e4d..f833a5d72815b 100644
--- a/airflow/api_connexion/endpoints/health_endpoint.py
+++ b/airflow/api_connexion/endpoints/health_endpoint.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from airflow.api_connexion.schemas.health_schema import health_schema
from airflow.api_connexion.types import APIResponse
@@ -24,7 +25,7 @@
def get_health() -> APIResponse:
- """Return the health of the airflow scheduler and metadatabase"""
+ """Return the health of the airflow scheduler and metadatabase."""
metadatabase_status = HEALTHY
latest_scheduler_heartbeat = None
scheduler_status = UNHEALTHY
diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py
index 9a46b3aac4392..f5798fe99e8c4 100644
--- a/airflow/api_connexion/endpoints/import_error_endpoint.py
+++ b/airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-from typing import Optional
+from __future__ import annotations
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -37,7 +36,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)])
@provide_session
def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse:
- """Get an import error"""
+ """Get an import error."""
error = session.query(ImportErrorModel).get(import_error_id)
if error is None:
@@ -49,18 +48,18 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) ->
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)])
-@format_parameters({'limit': check_limit})
+@format_parameters({"limit": check_limit})
@provide_session
def get_import_errors(
*,
limit: int,
- offset: Optional[int] = None,
+ offset: int | None = None,
order_by: str = "import_error_id",
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get all import errors"""
- to_replace = {"import_error_id": 'id'}
- allowed_filter_attrs = ['import_error_id', "timestamp", "filename"]
+ """Get all import errors."""
+ to_replace = {"import_error_id": "id"}
+ allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
total_entries = session.query(func.count(ImportErrorModel.id)).scalar()
query = session.query(ImportErrorModel)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py
index f1335fe527451..388b164727e90 100644
--- a/airflow/api_connexion/endpoints/log_endpoint.py
+++ b/airflow/api_connexion/endpoints/log_endpoint.py
@@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import Any, Optional
+from typing import Any
-from flask import Response, current_app, request
+from flask import Response, request
from itsdangerous.exc import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
from sqlalchemy.orm.session import Session
@@ -29,6 +30,7 @@
from airflow.exceptions import TaskNotFound
from airflow.models import TaskInstance
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.log.log_reader import TaskLogReader
from airflow.utils.session import NEW_SESSION, provide_session
@@ -48,11 +50,12 @@ def get_log(
task_id: str,
task_try_number: int,
full_content: bool = False,
- token: Optional[str] = None,
+ map_index: int = -1,
+ token: str | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get logs for specific task instance"""
- key = current_app.config["SECRET_KEY"]
+ """Get logs for specific task instance."""
+ key = get_airflow_app().config["SECRET_KEY"]
if not token:
metadata = {}
else:
@@ -61,47 +64,48 @@ def get_log(
except BadSignature:
raise BadRequest("Bad Signature. Please use only the tokens provided by the API.")
- if metadata.get('download_logs') and metadata['download_logs']:
+ if metadata.get("download_logs") and metadata["download_logs"]:
full_content = True
if full_content:
- metadata['download_logs'] = True
+ metadata["download_logs"] = True
else:
- metadata['download_logs'] = False
+ metadata["download_logs"] = False
task_log_reader = TaskLogReader()
if not task_log_reader.supports_read:
raise BadRequest("Task log handler does not support read logs.")
-
ti = (
session.query(TaskInstance)
.filter(
TaskInstance.task_id == task_id,
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
+ TaskInstance.map_index == map_index,
)
.join(TaskInstance.dag_run)
.one_or_none()
)
if ti is None:
- metadata['end_of_log'] = True
+ metadata["end_of_log"] = True
raise NotFound(title="TaskInstance not found")
- dag = current_app.dag_bag.get_dag(dag_id)
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
if dag:
try:
ti.task = dag.get_task(ti.task_id)
except TaskNotFound:
pass
- return_type = request.accept_mimetypes.best_match(['text/plain', 'application/json'])
+ return_type = request.accept_mimetypes.best_match(["text/plain", "application/json"])
# return_type would be either the above two or None
logs: Any
- if return_type == 'application/json' or return_type is None: # default
+ if return_type == "application/json" or return_type is None: # default
logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
logs = logs[0] if task_try_number is not None else logs
- token = URLSafeSerializer(key).dumps(metadata)
+ # the serialized token is always set at this point, so the assignment type error can be ignored safely
+ token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment]
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
diff --git a/airflow/api_connexion/endpoints/plugin_endpoint.py b/airflow/api_connexion/endpoints/plugin_endpoint.py
index 4b4b17bb96105..a2febda42c9b8 100644
--- a/airflow/api_connexion/endpoints/plugin_endpoint.py
+++ b/airflow/api_connexion/endpoints/plugin_endpoint.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from airflow.api_connexion import security
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.plugin_schema import PluginCollection, plugin_collection_schema
@@ -25,9 +27,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)])
@format_parameters({"limit": check_limit})
def get_plugins(*, limit: int, offset: int = 0) -> APIResponse:
- """Get plugins endpoint"""
+ """Get plugins endpoint."""
plugins_info = get_plugin_info()
- total_entries = len(plugins_info)
- plugins_info = plugins_info[offset:]
- plugins_info = plugins_info[:limit]
- return plugin_collection_schema.dump(PluginCollection(plugins=plugins_info, total_entries=total_entries))
+ collection = PluginCollection(plugins=plugins_info[offset:][:limit], total_entries=len(plugins_info))
+ return plugin_collection_schema.dump(collection)
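Note on the hunk above: the refactored `get_plugins` applies the offset and the limit as two chained slices and reports the unsliced length as `total_entries`. A minimal sketch of that pattern follows; `paginate` and the sample list are hypothetical names used only for illustration and are not part of Airflow.

```python
from typing import Any


def paginate(items: list[Any], offset: int = 0, limit: int = 100) -> dict[str, Any]:
    """Hypothetical helper mirroring the offset/limit slicing in the hunk above."""
    # Drop the first `offset` items, then keep at most `limit` of the rest.
    page = items[offset:][:limit]
    # The total is taken before slicing, so clients can compute the page count.
    return {"items": page, "total_entries": len(items)}


plugins = ["a", "b", "c", "d", "e"]
result = paginate(plugins, offset=1, limit=2)
past_the_end = paginate(plugins, offset=10, limit=2)
```

Slicing past the end of a list yields an empty list rather than raising, so an out-of-range offset produces an empty page with the correct total.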
diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py
index 1d24fea63d756..e6f7903e0ca4d 100644
--- a/airflow/api_connexion/endpoints/pool_endpoint.py
+++ b/airflow/api_connexion/endpoints/pool_endpoint.py
@@ -14,17 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from http import HTTPStatus
-from typing import Optional
-from flask import Response, request
+from flask import Response
from marshmallow import ValidationError
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from airflow.api_connexion import security
+from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema
@@ -37,7 +38,7 @@
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL)])
@provide_session
def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
- """Delete a pool"""
+ """Delete a pool."""
if pool_name == "default_pool":
raise BadRequest(detail="Default Pool can't be deleted")
affected_count = session.query(Pool).filter(Pool.pool == pool_name).delete()
@@ -49,7 +50,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)])
@provide_session
def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
- """Get a pool"""
+ """Get a pool."""
obj = session.query(Pool).filter(Pool.pool == pool_name).one_or_none()
if obj is None:
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
@@ -63,12 +64,12 @@ def get_pools(
*,
limit: int,
order_by: str = "id",
- offset: Optional[int] = None,
+ offset: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get all pools"""
+ """Get all pools."""
to_replace = {"name": "pool"}
- allowed_filter_attrs = ['name', 'slots', "id"]
+ allowed_filter_attrs = ["name", "slots", "id"]
total_entries = session.query(func.count(Pool.id)).scalar()
query = session.query(Pool)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
@@ -84,10 +85,11 @@ def patch_pool(
update_mask: UpdateMask = None,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Update a pool"""
+ """Update a pool."""
+ request_dict = get_json_request_dict()
# Only slots can be modified in 'default_pool'
try:
- if pool_name == Pool.DEFAULT_POOL_NAME and request.json["name"] != Pool.DEFAULT_POOL_NAME:
+ if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME:
if update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots":
pass
else:
@@ -100,7 +102,7 @@ def patch_pool(
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
try:
- patch_body = pool_schema.load(request.json)
+ patch_body = pool_schema.load(request_dict)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
@@ -121,7 +123,7 @@ def patch_pool(
else:
required_fields = {"name", "slots"}
- fields_diff = required_fields - set(request.json.keys())
+ fields_diff = required_fields - set(get_json_request_dict().keys())
if fields_diff:
raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}")
@@ -134,14 +136,14 @@ def patch_pool(
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL)])
@provide_session
def post_pool(*, session: Session = NEW_SESSION) -> APIResponse:
- """Create a pool"""
+ """Create a pool."""
required_fields = {"name", "slots"} # Pool would require both fields in the post request
- fields_diff = required_fields - set(request.json.keys())
+ fields_diff = required_fields - set(get_json_request_dict().keys())
if fields_diff:
raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}")
try:
- post_body = pool_schema.load(request.json, session=session)
+ post_body = pool_schema.load(get_json_request_dict(), session=session)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
diff --git a/airflow/api_connexion/endpoints/provider_endpoint.py b/airflow/api_connexion/endpoints/provider_endpoint.py
index 7526e284beb06..c829d9c968d61 100644
--- a/airflow/api_connexion/endpoints/provider_endpoint.py
+++ b/airflow/api_connexion/endpoints/provider_endpoint.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import re
@@ -42,7 +43,7 @@ def _provider_mapper(provider: ProviderInfo) -> Provider:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)])
def get_providers() -> APIResponse:
- """Get providers"""
+ """Get providers."""
providers = [_provider_mapper(d) for d in ProvidersManager().providers.values()]
total_entries = len(providers)
return provider_collection_schema.dump(
diff --git a/airflow/api_connexion/endpoints/request_dict.py b/airflow/api_connexion/endpoints/request_dict.py
new file mode 100644
index 0000000000000..b07e06c0b63f8
--- /dev/null
+++ b/airflow/api_connexion/endpoints/request_dict.py
@@ -0,0 +1,26 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, Mapping, cast
+
+
+def get_json_request_dict() -> Mapping[str, Any]:
+ """Return the JSON request body, typed as a string-keyed mapping."""
+ from flask import request
+
+ return cast(Mapping[str, Any], request.get_json())
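Note on the new helper above: `typing.cast` only informs the static type checker; at runtime it returns its argument unchanged and validates nothing. The sketch below illustrates this with a hypothetical stand-in, `as_json_mapping`, that takes the payload as a parameter instead of reading it from Flask's `request`.

```python
from typing import Any, Mapping, cast


def as_json_mapping(payload: object) -> Mapping[str, Any]:
    """Hypothetical stand-in for the helper: narrows the static type only."""
    # cast() returns `payload` as-is; no conversion or runtime check happens.
    return cast(Mapping[str, Any], payload)


body = {"name": "pool_a", "slots": 3}
same = as_json_mapping(body)
# A non-mapping value passes through untouched, which is why callers
# still need their own validation (for example a schema load).
not_a_mapping = as_json_mapping(None)
```

The helper therefore appears to exist to satisfy type checking at the call sites; under that reading, malformed bodies would still surface in the later schema validation.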
diff --git a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
index 88a68341c129e..4ed40caae508c 100644
--- a/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
+++ b/airflow/api_connexion/endpoints/role_and_permission_endpoint.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from http import HTTPStatus
-from typing import List, Optional, Tuple
from connexion import NoContent
-from flask import current_app, request
+from flask import request
from marshmallow import ValidationError
from sqlalchemy import asc, desc, func
@@ -35,13 +35,14 @@
)
from airflow.api_connexion.types import APIResponse, UpdateMask
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.www.fab_security.sqla.models import Action, Role
from airflow.www.security import AirflowSecurityManager
-def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str, str]]) -> None:
+def _check_action_and_resource(sm: AirflowSecurityManager, perms: list[tuple[str, str]]) -> None:
"""
- Checks if the action or resource exists and raise 400 if not
+ Checks if the action or resource exists and otherwise raises 400.
 This function is intended for use in the REST API because it raises 400
"""
@@ -54,8 +55,8 @@ def _check_action_and_resource(sm: AirflowSecurityManager, perms: List[Tuple[str
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)])
def get_role(*, role_name: str) -> APIResponse:
- """Get role"""
- ab_security_manager = current_app.appbuilder.sm
+ """Get role."""
+ ab_security_manager = get_airflow_app().appbuilder.sm
role = ab_security_manager.find_role(name=role_name)
if not role:
raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found")
@@ -64,9 +65,9 @@ def get_role(*, role_name: str) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)])
@format_parameters({"limit": check_limit})
-def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = None) -> APIResponse:
- """Get roles"""
- appbuilder = current_app.appbuilder
+def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None) -> APIResponse:
+ """Get roles."""
+ appbuilder = get_airflow_app().appbuilder
session = appbuilder.get_session
total_entries = session.query(func.count(Role.id)).scalar()
direction = desc if order_by.startswith("-") else asc
@@ -87,10 +88,10 @@ def get_roles(*, order_by: str = "name", limit: int, offset: Optional[int] = Non
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION)])
-@format_parameters({'limit': check_limit})
-def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse:
- """Get permissions"""
- session = current_app.appbuilder.get_session
+@format_parameters({"limit": check_limit})
+def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse:
+ """Get permissions."""
+ session = get_airflow_app().appbuilder.get_session
total_entries = session.query(func.count(Action.id)).scalar()
query = session.query(Action)
actions = query.offset(offset).limit(limit).all()
@@ -99,8 +100,8 @@ def get_permissions(*, limit: int, offset: Optional[int] = None) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE)])
def delete_role(*, role_name: str) -> APIResponse:
- """Delete a role"""
- ab_security_manager = current_app.appbuilder.sm
+ """Delete a role."""
+ ab_security_manager = get_airflow_app().appbuilder.sm
role = ab_security_manager.find_role(name=role_name)
if not role:
raise NotFound(title="Role not found", detail=f"Role with name {role_name!r} was not found")
@@ -110,8 +111,8 @@ def delete_role(*, role_name: str) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE)])
def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse:
- """Update a role"""
- appbuilder = current_app.appbuilder
+ """Update a role."""
+ appbuilder = get_airflow_app().appbuilder
security_manager = appbuilder.sm
body = request.json
try:
@@ -128,7 +129,7 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse
if field in data and not field == "permissions":
data_[field] = data[field]
elif field == "actions":
- data_["permissions"] = data['permissions']
+ data_["permissions"] = data["permissions"]
else:
raise BadRequest(detail=f"'{field}' in update_mask is unknown")
data = data_
@@ -144,17 +145,17 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE)])
def post_role() -> APIResponse:
- """Create a new role"""
- appbuilder = current_app.appbuilder
+ """Create a new role."""
+ appbuilder = get_airflow_app().appbuilder
security_manager = appbuilder.sm
body = request.json
try:
data = role_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
- role = security_manager.find_role(name=data['name'])
+ role = security_manager.find_role(name=data["name"])
if not role:
- perms = [(item['action']['name'], item['resource']['name']) for item in data['permissions'] if item]
+ perms = [(item["action"]["name"], item["resource"]["name"]) for item in data["permissions"] if item]
_check_action_and_resource(security_manager, perms)
security_manager.bulk_sync_roles([{"role": data["name"], "perms": perms}])
return role_schema.dump(role)
diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py
index 28c39b000c28d..23c2b32487b31 100644
--- a/airflow/api_connexion/endpoints/task_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_endpoint.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from operator import attrgetter
+from __future__ import annotations
-from flask import current_app
+from operator import attrgetter
from airflow import DAG
from airflow.api_connexion import security
@@ -25,6 +25,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.exceptions import TaskNotFound
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
@security.requires_access(
@@ -35,7 +36,7 @@
)
def get_task(*, dag_id: str, task_id: str) -> APIResponse:
"""Get simplified representation of a task."""
- dag: DAG = current_app.dag_bag.get_dag(dag_id)
+ dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found")
@@ -53,14 +54,14 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse:
],
)
def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse:
- """Get tasks for DAG"""
- dag: DAG = current_app.dag_bag.get_dag(dag_id)
+ """Get tasks for DAG."""
+ dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found")
tasks = dag.tasks
try:
- tasks = sorted(tasks, key=attrgetter(order_by.lstrip('-')), reverse=(order_by[0:1] == '-'))
+ tasks = sorted(tasks, key=attrgetter(order_by.lstrip("-")), reverse=(order_by[0:1] == "-"))
except AttributeError as err:
raise BadRequest(detail=str(err))
task_collection = TaskCollection(tasks=tasks, total_entries=len(tasks))
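Note on the `get_tasks` hunk above: the sort key is built from `order_by` by stripping a leading `-`, and the same prefix toggles descending order. A minimal sketch of that idiom, assuming a hypothetical `FakeTask` class in place of real Airflow operators:

```python
from dataclasses import dataclass
from operator import attrgetter


@dataclass
class FakeTask:
    """Hypothetical stand-in for a task object; not an Airflow class."""

    task_id: str
    retries: int


def sort_tasks(tasks: list[FakeTask], order_by: str = "task_id") -> list[FakeTask]:
    # A leading "-" requests descending order; the rest names the attribute.
    descending = order_by[0:1] == "-"
    return sorted(tasks, key=attrgetter(order_by.lstrip("-")), reverse=descending)


tasks = [FakeTask("b", 1), FakeTask("a", 3), FakeTask("c", 2)]
ascending_ids = [t.task_id for t in sort_tasks(tasks)]
by_retries_desc = [t.task_id for t in sort_tasks(tasks, "-retries")]

# An unknown attribute surfaces as AttributeError, which the endpoint
# above converts into a BadRequest.
try:
    sort_tasks(tasks, "no_such_field")
    unknown_field_failed = False
except AttributeError:
    unknown_field_failed = True
```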
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index c2416ab0d9d44..9d5d54ba5809c 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Iterable, List, Optional, Tuple, TypeVar
+from __future__ import annotations
+
+from typing import Any, Iterable, TypeVar
-from flask import current_app, request
from marshmallow import ValidationError
from sqlalchemy import and_, func, or_
from sqlalchemy.exc import MultipleResultsFound
@@ -25,23 +26,29 @@
from sqlalchemy.sql import ClauseElement
from airflow.api_connexion import security
+from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import format_datetime, format_parameters
from airflow.api_connexion.schemas.task_instance_schema import (
TaskInstanceCollection,
TaskInstanceReferenceCollection,
clear_task_instance_form,
+ set_single_task_instance_state_form,
+ set_task_instance_note_form_schema,
set_task_instance_state_form,
task_instance_batch_form,
task_instance_collection_schema,
task_instance_reference_collection_schema,
+ task_instance_reference_schema,
task_instance_schema,
)
from airflow.api_connexion.types import APIResponse
from airflow.models import SlaMiss
from airflow.models.dagrun import DagRun as DR
+from airflow.models.operator import needs_expansion
from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState, State
@@ -63,7 +70,7 @@ def get_task_instance(
task_id: str,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get task instance"""
+ """Get task instance."""
query = (
session.query(TI)
.filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
@@ -112,7 +119,7 @@ def get_mapped_task_instance(
map_index: int,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get task instance"""
+ """Get task instance."""
query = (
session.query(TI)
.filter(
@@ -160,20 +167,20 @@ def get_mapped_task_instances(
dag_id: str,
dag_run_id: str,
task_id: str,
- execution_date_gte: Optional[str] = None,
- execution_date_lte: Optional[str] = None,
- start_date_gte: Optional[str] = None,
- start_date_lte: Optional[str] = None,
- end_date_gte: Optional[str] = None,
- end_date_lte: Optional[str] = None,
- duration_gte: Optional[float] = None,
- duration_lte: Optional[float] = None,
- state: Optional[List[str]] = None,
- pool: Optional[List[str]] = None,
- queue: Optional[List[str]] = None,
- limit: Optional[int] = None,
- offset: Optional[int] = None,
- order_by: Optional[str] = None,
+ execution_date_gte: str | None = None,
+ execution_date_lte: str | None = None,
+ start_date_gte: str | None = None,
+ start_date_lte: str | None = None,
+ end_date_gte: str | None = None,
+ end_date_lte: str | None = None,
+ duration_gte: float | None = None,
+ duration_lte: float | None = None,
+ state: list[str] | None = None,
+ pool: list[str] | None = None,
+ queue: list[str] | None = None,
+ limit: int | None = None,
+ offset: int | None = None,
+ order_by: str | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get list of task instances."""
@@ -187,8 +194,8 @@ def get_mapped_task_instances(
)
# 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404
- if base_query.with_entities(func.count('*')).scalar() == 0:
- dag = current_app.dag_bag.get_dag(dag_id)
+ if base_query.with_entities(func.count("*")).scalar() == 0:
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
error_message = f"DAG {dag_id} not found"
raise NotFound(error_message)
@@ -196,7 +203,7 @@ def get_mapped_task_instances(
if not task:
error_message = f"Task id {task_id} not found"
raise NotFound(error_message)
- if not task.is_mapped:
+ if not needs_expansion(task):
error_message = f"Task id {task_id} is not mapped"
raise NotFound(error_message)
@@ -214,7 +221,7 @@ def get_mapped_task_instances(
query = _apply_array_filter(query, key=TI.queue, values=queue)
# Count elements before joining extra columns
- total_entries = query.with_entities(func.count('*')).scalar()
+ total_entries = query.with_entities(func.count("*")).scalar()
# Add SLA miss
query = (
@@ -232,11 +239,11 @@ def get_mapped_task_instances(
)
if order_by:
- if order_by == 'state':
+ if order_by == "state":
query = query.order_by(TI.state.asc(), TI.map_index.asc())
- elif order_by == '-state':
+ elif order_by == "-state":
query = query.order_by(TI.state.desc(), TI.map_index.asc())
- elif order_by == '-map_index':
+ elif order_by == "-map_index":
query = query.order_by(TI.map_index.desc())
else:
raise BadRequest(detail=f"Ordering with '{order_by}' is not supported")
@@ -249,20 +256,20 @@ def get_mapped_task_instances(
)
-def _convert_state(states: Optional[Iterable[str]]) -> Optional[List[Optional[str]]]:
+def _convert_state(states: Iterable[str] | None) -> list[str | None] | None:
if not states:
return None
return [State.NONE if s == "none" else s for s in states]
-def _apply_array_filter(query: Query, key: ClauseElement, values: Optional[Iterable[Any]]) -> Query:
+def _apply_array_filter(query: Query, key: ClauseElement, values: Iterable[Any] | None) -> Query:
if values is not None:
cond = ((key == v) for v in values)
query = query.filter(or_(*cond))
return query
-def _apply_range_filter(query: Query, key: ClauseElement, value_range: Tuple[T, T]) -> Query:
+def _apply_range_filter(query: Query, key: ClauseElement, value_range: tuple[T, T]) -> Query:
gte_value, lte_value = value_range
if gte_value is not None:
query = query.filter(key >= gte_value)
@@ -292,20 +299,20 @@ def _apply_range_filter(query: Query, key: ClauseElement, value_range: Tuple[T,
def get_task_instances(
*,
limit: int,
- dag_id: Optional[str] = None,
- dag_run_id: Optional[str] = None,
- execution_date_gte: Optional[str] = None,
- execution_date_lte: Optional[str] = None,
- start_date_gte: Optional[str] = None,
- start_date_lte: Optional[str] = None,
- end_date_gte: Optional[str] = None,
- end_date_lte: Optional[str] = None,
- duration_gte: Optional[float] = None,
- duration_lte: Optional[float] = None,
- state: Optional[List[str]] = None,
- pool: Optional[List[str]] = None,
- queue: Optional[List[str]] = None,
- offset: Optional[int] = None,
+ dag_id: str | None = None,
+ dag_run_id: str | None = None,
+ execution_date_gte: str | None = None,
+ execution_date_lte: str | None = None,
+ start_date_gte: str | None = None,
+ start_date_lte: str | None = None,
+ end_date_gte: str | None = None,
+ end_date_lte: str | None = None,
+ duration_gte: float | None = None,
+ duration_lte: float | None = None,
+ state: list[str] | None = None,
+ pool: list[str] | None = None,
+ queue: list[str] | None = None,
+ offset: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get list of task instances."""
@@ -333,7 +340,7 @@ def get_task_instances(
base_query = _apply_array_filter(base_query, key=TI.queue, values=queue)
# Count elements before joining extra columns
- total_entries = base_query.with_entities(func.count('*')).scalar()
+ total_entries = base_query.with_entities(func.count("*")).scalar()
# Add join
query = (
base_query.join(
@@ -364,12 +371,12 @@ def get_task_instances(
@provide_session
def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
"""Get list of task instances."""
- body = request.get_json()
+ body = get_json_request_dict()
try:
data = task_instance_batch_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
- states = _convert_state(data['state'])
+ states = _convert_state(data["state"])
base_query = session.query(TI).join(TI.dag_run)
base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
@@ -394,7 +401,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
base_query = _apply_array_filter(base_query, key=TI.queue, values=data["queue"])
# Count elements before joining extra columns
- total_entries = base_query.with_entities(func.count('*')).scalar()
+ total_entries = base_query.with_entities(func.count("*")).scalar()
# Add join
base_query = base_query.join(
SlaMiss,
@@ -423,20 +430,50 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse:
@provide_session
def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Clear task instances."""
- body = request.get_json()
+ body = get_json_request_dict()
try:
data = clear_task_instance_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
- dag = current_app.dag_bag.get_dag(dag_id)
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
error_message = f"Dag id {dag_id} not found"
raise NotFound(error_message)
- reset_dag_runs = data.pop('reset_dag_runs')
- dry_run = data.pop('dry_run')
+ reset_dag_runs = data.pop("reset_dag_runs")
+ dry_run = data.pop("dry_run")
# We always pass dry_run here, otherwise this would try to confirm on the terminal!
- task_instances = dag.clear(dry_run=True, dag_bag=current_app.dag_bag, **data)
+ dag_run_id = data.pop("dag_run_id", None)
+ future = data.pop("include_future", False)
+ past = data.pop("include_past", False)
+ downstream = data.pop("include_downstream", False)
+ upstream = data.pop("include_upstream", False)
+ if dag_run_id is not None:
+ dag_run: DR | None = (
+ session.query(DR).filter(DR.dag_id == dag_id, DR.run_id == dag_run_id).one_or_none()
+ )
+ if dag_run is None:
+ error_message = f"Dag Run id {dag_run_id} not found in dag {dag_id}"
+ raise NotFound(error_message)
+ data["start_date"] = dag_run.logical_date
+ data["end_date"] = dag_run.logical_date
+ if past:
+ data["start_date"] = None
+ if future:
+ data["end_date"] = None
+ task_ids = data.pop("task_ids", None)
+ if task_ids is not None:
+ task_id = [task[0] if isinstance(task, tuple) else task for task in task_ids]
+ dag = dag.partial_subset(
+ task_ids_or_regex=task_id,
+ include_downstream=downstream,
+ include_upstream=upstream,
+ )
+
+ if len(dag.task_dict) > 1:
+ # If we had upstream/downstream etc then also include those!
+ task_ids.extend(tid for tid in dag.task_dict if tid not in task_id)
+ task_instances = dag.clear(dry_run=True, dag_bag=get_airflow_app().dag_bag, task_ids=task_ids, **data)
if not dry_run:
clear_task_instances(
task_instances.all(),
@@ -460,26 +497,26 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) ->
@provide_session
def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set a state of task instances."""
- body = request.get_json()
+ body = get_json_request_dict()
try:
data = set_task_instance_state_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
error_message = f"Dag ID {dag_id} not found"
- dag = current_app.dag_bag.get_dag(dag_id)
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound(error_message)
- task_id = data['task_id']
+ task_id = data["task_id"]
task = dag.task_dict.get(task_id)
if not task:
error_message = f"Task ID {task_id} not found"
raise NotFound(error_message)
- execution_date = data.get('execution_date')
- run_id = data.get('dag_run_id')
+ execution_date = data.get("execution_date")
+ run_id = data.get("dag_run_id")
if (
execution_date
and (
@@ -494,7 +531,7 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
)
if run_id and not session.query(TI).get(
- {'task_id': task_id, 'dag_id': dag_id, 'run_id': run_id, 'map_index': -1}
+ {"task_id": task_id, "dag_id": dag_id, "run_id": run_id, "map_index": -1}
):
error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {run_id!r}"
raise NotFound(detail=error_message)
@@ -512,3 +549,135 @@ def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION
session=session,
)
return task_instance_reference_collection_schema.dump(TaskInstanceReferenceCollection(task_instances=tis))
+
+
+def set_mapped_task_instance_note(
+ *, dag_id: str, dag_run_id: str, task_id: str, map_index: int
+) -> APIResponse:
+ """Set the note for a Mapped Task instance."""
+ return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index)
+
+
+@security.requires_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+ ],
+)
+@provide_session
+def patch_task_instance(
+ *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION
+) -> APIResponse:
+ """Update the state of a task instance."""
+ body = get_json_request_dict()
+ try:
+ data = set_single_task_instance_state_form.load(body)
+ except ValidationError as err:
+ raise BadRequest(detail=str(err.messages))
+
+ dag = get_airflow_app().dag_bag.get_dag(dag_id)
+ if not dag:
+ raise NotFound("DAG not found", detail=f"DAG {dag_id!r} not found")
+
+ if not dag.has_task(task_id):
+ raise NotFound("Task not found", detail=f"Task {task_id!r} not found in DAG {dag_id!r}")
+
+ ti: TI | None = session.query(TI).get(
+ {"task_id": task_id, "dag_id": dag_id, "run_id": dag_run_id, "map_index": map_index}
+ )
+
+ if not ti:
+ error_message = f"Task instance not found for task {task_id!r} on DAG run with ID {dag_run_id!r}"
+ raise NotFound(detail=error_message)
+
+ if not data["dry_run"]:
+ ti = dag.set_task_instance_state(
+ task_id=task_id,
+ run_id=dag_run_id,
+ map_indexes=[map_index],
+ state=data["new_state"],
+ commit=True,
+ session=session,
+ )
+
+ return task_instance_reference_schema.dump(ti)
+
+
+@security.requires_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+ ],
+)
+@provide_session
+def patch_mapped_task_instance(
+ *, dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: Session = NEW_SESSION
+) -> APIResponse:
+ """Update the state of a mapped task instance."""
+ return patch_task_instance(
+ dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index, session=session
+ )
+
+
+@security.requires_access(
+ [
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
+ (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
+ (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE),
+ ],
+)
+@provide_session
+def set_task_instance_note(
+ *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION
+) -> APIResponse:
+ """Set the note for a Task instance. This supports both Mapped and non-Mapped Task instances."""
+ try:
+ post_body = set_task_instance_note_form_schema.load(get_json_request_dict())
+ new_note = post_body["note"]
+ except ValidationError as err:
+ raise BadRequest(detail=str(err))
+
+ query = (
+ session.query(TI)
+ .filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
+ .join(TI.dag_run)
+ .outerjoin(
+ SlaMiss,
+ and_(
+ SlaMiss.dag_id == TI.dag_id,
+ SlaMiss.execution_date == DR.execution_date,
+ SlaMiss.task_id == TI.task_id,
+ ),
+ )
+ .add_entity(SlaMiss)
+ .options(joinedload(TI.rendered_task_instance_fields))
+ )
+ if map_index == -1:
+ query = query.filter(or_(TI.map_index == -1, TI.map_index.is_(None)))
+ else:
+ query = query.filter(TI.map_index == map_index)
+
+ try:
+ result = query.one_or_none()
+ except MultipleResultsFound:
+ raise NotFound(
+ "Task instance not found", detail="Task instance is mapped, add the map_index value to the URL"
+ )
+ if result is None:
+ error_message = f"Task Instance not found for dag_id={dag_id}, run_id={dag_run_id}, task_id={task_id}"
+ raise NotFound(error_message)
+
+ ti, sla_miss = result
+ from flask_login import current_user
+
+ current_user_id = getattr(current_user, "id", None)
+ if ti.task_instance_note is None:
+ ti.note = (new_note, current_user_id)
+ else:
+ ti.task_instance_note.content = new_note
+ ti.task_instance_note.user_id = current_user_id
+ session.commit()
+ return task_instance_schema.dump((ti, sla_miss))
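
The `dag_run_id` handling added to `post_clear_task_instances` above first narrows the clear window to a single run's logical date, then widens it again when `include_past` or `include_future` is set. A minimal standalone sketch of that window logic follows. It assumes a hypothetical helper name, `resolve_clear_window`, which is not part of Airflow, and uses plain `datetime` values in place of a `DagRun` row.

```python
from datetime import datetime
from typing import Optional, Tuple


def resolve_clear_window(
    logical_date: datetime,
    *,
    include_past: bool = False,
    include_future: bool = False,
) -> Tuple[Optional[datetime], Optional[datetime]]:
    """Return (start_date, end_date) for clearing around one DAG run.

    Hypothetical helper for illustration only; None means "unbounded"
    on that side of the window.
    """
    # Start with a window that covers exactly the selected run.
    start_date: Optional[datetime] = logical_date
    end_date: Optional[datetime] = logical_date
    # Dropping a bound extends the window to earlier or later runs.
    if include_past:
        start_date = None
    if include_future:
        end_date = None
    return start_date, end_date
```

With both flags set, the window is unbounded on both sides, which matches the endpoint's behaviour when no `dag_run_id` is given and no dates are supplied.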
diff --git a/airflow/api_connexion/endpoints/user_endpoint.py b/airflow/api_connexion/endpoints/user_endpoint.py
index 6b4e984a69559..506e11e00612c 100644
--- a/airflow/api_connexion/endpoints/user_endpoint.py
+++ b/airflow/api_connexion/endpoints/user_endpoint.py
@@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from http import HTTPStatus
-from typing import List, Optional
from connexion import NoContent
-from flask import current_app, request
+from flask import request
from marshmallow import ValidationError
from sqlalchemy import asc, desc, func
from werkzeug.security import generate_password_hash
@@ -34,13 +35,14 @@
)
from airflow.api_connexion.types import APIResponse, UpdateMask
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.www.fab_security.sqla.models import Role, User
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER)])
def get_user(*, username: str) -> APIResponse:
- """Get a user"""
- ab_security_manager = current_app.appbuilder.sm
+ """Get a user."""
+ ab_security_manager = get_airflow_app().appbuilder.sm
user = ab_security_manager.find_user(username=username)
if not user:
raise NotFound(title="User not found", detail=f"The User with username `{username}` was not found")
@@ -49,9 +51,9 @@ def get_user(*, username: str) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_USER)])
@format_parameters({"limit": check_limit})
-def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None) -> APIResponse:
- """Get users"""
- appbuilder = current_app.appbuilder
+def get_users(*, limit: int, order_by: str = "id", offset: str | None = None) -> APIResponse:
+ """Get users."""
+ appbuilder = get_airflow_app().appbuilder
session = appbuilder.get_session
total_entries = session.query(func.count(User.id)).scalar()
direction = desc if order_by.startswith("-") else asc
@@ -59,7 +61,7 @@ def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None)
order_param = order_by.strip("-")
order_param = to_replace.get(order_param, order_param)
allowed_filter_attrs = [
- 'id',
+ "id",
"first_name",
"last_name",
"user_name",
@@ -81,13 +83,13 @@ def get_users(*, limit: int, order_by: str = "id", offset: Optional[str] = None)
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_USER)])
def post_user() -> APIResponse:
- """Create a new user"""
+ """Create a new user."""
try:
data = user_schema.load(request.json)
except ValidationError as e:
raise BadRequest(detail=str(e.messages))
- security_manager = current_app.appbuilder.sm
+ security_manager = get_airflow_app().appbuilder.sm
username = data["username"]
email = data["email"]
@@ -124,26 +126,26 @@ def post_user() -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_USER)])
def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse:
- """Update a user"""
+ """Update a user."""
try:
data = user_schema.load(request.json)
except ValidationError as e:
raise BadRequest(detail=str(e.messages))
- security_manager = current_app.appbuilder.sm
+ security_manager = get_airflow_app().appbuilder.sm
user = security_manager.find_user(username=username)
if user is None:
detail = f"The User with username `{username}` was not found"
raise NotFound(title="User not found", detail=detail)
# Check unique username
- new_username = data.get('username')
+ new_username = data.get("username")
if new_username and new_username != username:
if security_manager.find_user(username=new_username):
raise AlreadyExists(detail=f"The username `{new_username}` already exists")
# Check unique email
- email = data.get('email')
+ email = data.get("email")
if email and email != user.email:
if security_manager.find_user(email=email):
raise AlreadyExists(detail=f"The email `{email}` already exists")
@@ -163,7 +165,7 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse:
raise BadRequest(detail=detail)
data = masked_data
- roles_to_update: Optional[List[Role]]
+ roles_to_update: list[Role] | None
if "roles" in data:
roles_to_update = []
missing_role_names = []
@@ -193,8 +195,8 @@ def patch_user(*, username: str, update_mask: UpdateMask = None) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_USER)])
def delete_user(*, username: str) -> APIResponse:
- """Delete a user"""
- security_manager = current_app.appbuilder.sm
+ """Delete a user."""
+ security_manager = get_airflow_app().appbuilder.sm
user = security_manager.find_user(username=username)
if user is None:
diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py
index 487d2cc486c83..3111ff18d424d 100644
--- a/airflow/api_connexion/endpoints/variable_endpoint.py
+++ b/airflow/api_connexion/endpoints/variable_endpoint.py
@@ -14,40 +14,52 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from http import HTTPStatus
-from typing import Optional
-from flask import Response, request
+from flask import Response
from marshmallow import ValidationError
from sqlalchemy import func
from sqlalchemy.orm import Session
from airflow.api_connexion import security
+from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.schemas.variable_schema import variable_collection_schema, variable_schema
from airflow.api_connexion.types import UpdateMask
from airflow.models import Variable
from airflow.security import permissions
+from airflow.utils.log.action_logger import action_event_from_permission
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.www.decorators import action_logging
+
+RESOURCE_EVENT_PREFIX = "variable"
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE)])
+@action_logging(
+ event=action_event_from_permission(
+ prefix=RESOURCE_EVENT_PREFIX,
+ permission=permissions.ACTION_CAN_DELETE,
+ ),
+)
def delete_variable(*, variable_key: str) -> Response:
- """Delete variable"""
+ """Delete variable."""
if Variable.delete(variable_key) == 0:
raise NotFound("Variable not found")
return Response(status=HTTPStatus.NO_CONTENT)
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)])
-def get_variable(*, variable_key: str) -> Response:
- """Get a variables by key"""
- try:
- var = Variable.get(variable_key)
- except KeyError:
+@provide_session
+def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Response:
+ """Get a variable by key."""
+ var = session.query(Variable).filter(Variable.key == variable_key)
+ if not var.count():
raise NotFound("Variable not found")
- return variable_schema.dump({"key": variable_key, "val": var})
+ return variable_schema.dump(var.first())
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)])
@@ -55,15 +67,15 @@ def get_variable(*, variable_key: str) -> Response:
@provide_session
def get_variables(
*,
- limit: Optional[int],
+ limit: int | None,
order_by: str = "id",
- offset: Optional[int] = None,
+ offset: int | None = None,
session: Session = NEW_SESSION,
) -> Response:
- """Get all variable values"""
+ """Get all variable values."""
total_entries = session.query(func.count(Variable.id)).scalar()
to_replace = {"value": "val"}
- allowed_filter_attrs = ['value', 'key', 'id']
+ allowed_filter_attrs = ["value", "key", "id"]
query = session.query(Variable)
query = apply_sorting(query, order_by, to_replace, allowed_filter_attrs)
variables = query.offset(offset).limit(limit).all()
@@ -76,10 +88,16 @@ def get_variables(
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)])
+@action_logging(
+ event=action_event_from_permission(
+ prefix=RESOURCE_EVENT_PREFIX,
+ permission=permissions.ACTION_CAN_EDIT,
+ ),
+)
def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Response:
- """Update a variable by key"""
+ """Update a variable by key."""
try:
- data = variable_schema.load(request.json)
+ data = variable_schema.load(get_json_request_dict())
except ValidationError as err:
raise BadRequest("Invalid Variable schema", detail=str(err.messages))
@@ -97,10 +115,16 @@ def patch_variable(*, variable_key: str, update_mask: UpdateMask = None) -> Resp
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)])
+@action_logging(
+ event=action_event_from_permission(
+ prefix=RESOURCE_EVENT_PREFIX,
+ permission=permissions.ACTION_CAN_CREATE,
+ ),
+)
def post_variables() -> Response:
- """Create a variable"""
+ """Create a variable."""
try:
- data = variable_schema.load(request.json)
+ data = variable_schema.load(get_json_request_dict())
except ValidationError as err:
raise BadRequest("Invalid Variable schema", detail=str(err.messages))
diff --git a/airflow/api_connexion/endpoints/version_endpoint.py b/airflow/api_connexion/endpoints/version_endpoint.py
index 077d7f8a1cfe4..79b4d2f1e1719 100644
--- a/airflow/api_connexion/endpoints/version_endpoint.py
+++ b/airflow/api_connexion/endpoints/version_endpoint.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import NamedTuple, Optional
+from typing import NamedTuple
import airflow
from airflow.api_connexion.schemas.version_schema import version_info_schema
@@ -24,14 +25,14 @@
class VersionInfo(NamedTuple):
- """Version information"""
+ """Version information."""
version: str
- git_version: Optional[str]
+ git_version: str | None
def get_version() -> APIResponse:
- """Get version information"""
+ """Get version information."""
airflow_version = airflow.__version__
git_version = get_airflow_git_version()
diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index 9cc6b6d79a933..2ab5ec26f5a2a 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional
+from __future__ import annotations
-from flask import current_app, g
+import copy
+
+from flask import g
from sqlalchemy import and_
from sqlalchemy.orm import Session
@@ -27,6 +29,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models import DagRun as DR, XCom
from airflow.security import permissions
+from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
@@ -45,14 +48,14 @@ def get_xcom_entries(
dag_id: str,
dag_run_id: str,
task_id: str,
- limit: Optional[int],
- offset: Optional[int] = None,
+ limit: int | None,
+ offset: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get all XCom values"""
+ """Get all XCom values."""
query = session.query(XCom)
- if dag_id == '~':
- appbuilder = current_app.appbuilder
+ if dag_id == "~":
+ appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
query = query.filter(XCom.dag_id.in_(readable_dag_ids))
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
@@ -60,14 +63,14 @@ def get_xcom_entries(
query = query.filter(XCom.dag_id == dag_id)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
- if task_id != '~':
+ if task_id != "~":
query = query.filter(XCom.task_id == task_id)
- if dag_run_id != '~':
+ if dag_run_id != "~":
query = query.filter(DR.run_id == dag_run_id)
query = query.order_by(DR.execution_date, XCom.task_id, XCom.dag_id, XCom.key)
total_entries = query.count()
query = query.offset(offset).limit(limit)
- return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
+ return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries))
@security.requires_access(
@@ -85,14 +88,28 @@ def get_xcom_entry(
task_id: str,
dag_run_id: str,
xcom_key: str,
+ deserialize: bool = False,
session: Session = NEW_SESSION,
) -> APIResponse:
- """Get an XCom entry"""
- query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
+ """Get an XCom entry."""
+ if deserialize:
+ query = session.query(XCom, XCom.value)
+ else:
+ query = session.query(XCom)
+
+ query = query.filter(XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.key == xcom_key)
query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id))
query = query.filter(DR.run_id == dag_run_id)
- query_object = query.one_or_none()
- if not query_object:
+ item = query.one_or_none()
+ if item is None:
raise NotFound("XCom entry not found")
- return xcom_schema.dump(query_object)
+
+ if deserialize:
+ xcom, value = item
+ stub = copy.copy(xcom)
+ stub.value = value
+ stub.value = XCom.deserialize_value(stub)
+ item = stub
+
+ return xcom_schema.dump(item)
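
The `deserialize` branch in `get_xcom_entry` above works on a shallow copy, so the ORM-loaded row is not mutated before it is dumped. A rough standalone sketch of that pattern follows. It assumes a hypothetical `Entry` class and uses JSON as a stand-in for a custom backend's `deserialize_value`; neither is Airflow's actual implementation.

```python
import copy
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class Entry:
    """Hypothetical stand-in for an XCom row; `value` holds serialized bytes."""

    key: str
    value: Any


def deserialized_stub(entry: Entry) -> Entry:
    """Return a shallow copy whose value is decoded, leaving `entry` untouched."""
    stub = copy.copy(entry)
    # JSON decoding stands in for a backend-specific deserializer (assumption).
    stub.value = json.loads(entry.value)
    return stub
```

Copying before assignment keeps the original object's `value` as the raw payload, so any code still holding the original reference sees unchanged data.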
diff --git a/airflow/api_connexion/exceptions.py b/airflow/api_connexion/exceptions.py
index 8fb7f2e78883b..11468e1506feb 100644
--- a/airflow/api_connexion/exceptions.py
+++ b/airflow/api_connexion/exceptions.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from http import HTTPStatus
-from typing import Any, Dict, Optional
+from typing import Any
import flask
import werkzeug
@@ -71,13 +73,13 @@ def common_error_handler(exception: BaseException) -> flask.Response:
class NotFound(ProblemException):
- """Raise when the object cannot be found"""
+ """Raise when the object cannot be found."""
def __init__(
self,
- title: str = 'Not Found',
- detail: Optional[str] = None,
- headers: Optional[Dict] = None,
+ title: str = "Not Found",
+ detail: str | None = None,
+ headers: dict | None = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -91,13 +93,13 @@ def __init__(
class BadRequest(ProblemException):
- """Raise when the server processes a bad request"""
+ """Raise when the server processes a bad request."""
def __init__(
self,
title: str = "Bad Request",
- detail: Optional[str] = None,
- headers: Optional[Dict] = None,
+ detail: str | None = None,
+ headers: dict | None = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -111,13 +113,13 @@ def __init__(
class Unauthenticated(ProblemException):
- """Raise when the user is not authenticated"""
+ """Raise when the user is not authenticated."""
def __init__(
self,
title: str = "Unauthorized",
- detail: Optional[str] = None,
- headers: Optional[Dict] = None,
+ detail: str | None = None,
+ headers: dict | None = None,
**kwargs: Any,
):
super().__init__(
@@ -131,13 +133,13 @@ def __init__(
class PermissionDenied(ProblemException):
- """Raise when the user does not have the required permissions"""
+ """Raise when the user does not have the required permissions."""
def __init__(
self,
title: str = "Forbidden",
- detail: Optional[str] = None,
- headers: Optional[Dict] = None,
+ detail: str | None = None,
+ headers: dict | None = None,
**kwargs: Any,
) -> None:
super().__init__(
@@ -151,13 +153,13 @@ def __init__(
class AlreadyExists(ProblemException):
- """Raise when the object already exists"""
+ """Raise when the object already exists."""
def __init__(
self,
title="Conflict",
- detail: Optional[str] = None,
- headers: Optional[Dict] = None,
+ detail: str | None = None,
+ headers: dict | None = None,
**kwargs: Any,
):
super().__init__(
@@ -171,13 +173,13 @@ def __init__(
class Unknown(ProblemException):
- """Returns a response body and status code for HTTP 500 exception"""
+ """Returns a response body and status code for HTTP 500 exception."""
def __init__(
self,
title: str = "Internal Server Error",
- detail: Optional[str] = None,
- headers: Optional[Dict] = None,
+ detail: str | None = None,
+ headers: dict | None = None,
**kwargs: Any,
) -> None:
super().__init__(
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index fc71fbb7f5e3d..3f25b54640741 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -229,7 +229,7 @@ info:
This means that the server encountered an unexpected condition that prevented it from
fulfilling the request.
- version: '1.0.0'
+ version: '2.5.1'
license:
name: Apache 2.0
url: http://www.apache.org/licenses/LICENSE-2.0.html
@@ -569,7 +569,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/ClearTaskInstance'
+ $ref: '#/components/schemas/ClearTaskInstances'
responses:
'200':
@@ -585,6 +585,85 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
+ /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/setNote:
+ parameters:
+ - $ref: '#/components/parameters/DAGID'
+ - $ref: '#/components/parameters/DAGRunID'
+ - $ref: '#/components/parameters/TaskID'
+
+ patch:
+ summary: Update the TaskInstance note.
+ description: |
+ Update the manual user note of a non-mapped Task Instance.
+
+ *New in version 2.5.0*
+ x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint
+ operationId: set_task_instance_note
+ tags: [TaskInstance]
+ requestBody:
+ description: Parameters of set Task Instance note.
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SetTaskInstanceNote'
+
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstance'
+ '400':
+ $ref: '#/components/responses/BadRequest'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
+ /dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/setNote:
+ parameters:
+ - $ref: '#/components/parameters/DAGID'
+ - $ref: '#/components/parameters/DAGRunID'
+ - $ref: '#/components/parameters/TaskID'
+ - $ref: '#/components/parameters/MapIndex'
+
+ patch:
+ summary: Update the TaskInstance note.
+ description: |
+ Update the manual user note of a mapped Task Instance.
+
+ *New in version 2.5.0*
+ x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint
+ operationId: set_mapped_task_instance_note
+ tags: [TaskInstance]
+ requestBody:
+ description: Parameters of set Task Instance note.
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SetTaskInstanceNote'
+
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstance'
+ '400':
+ $ref: '#/components/responses/BadRequest'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
/dags/{dag_id}/updateTaskInstancesState:
parameters:
- $ref: '#/components/parameters/DAGID'
@@ -818,6 +897,70 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
+ /dags/{dag_id}/dagRuns/{dag_run_id}/upstreamDatasetEvents:
+ parameters:
+ - $ref: '#/components/parameters/DAGID'
+ - $ref: '#/components/parameters/DAGRunID'
+ get:
+ summary: Get dataset events for a DAG run
+ description: |
+ Get dataset events for a DAG run.
+
+ *New in version 2.4.0*
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint
+ operationId: get_upstream_dataset_events
+ tags: [DAGRun, Dataset]
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DatasetEventCollection'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
+ /dags/{dag_id}/dagRuns/{dag_run_id}/setNote:
+ parameters:
+ - $ref: '#/components/parameters/DAGID'
+ - $ref: '#/components/parameters/DAGRunID'
+ patch:
+ summary: Update the DagRun note.
+ description: |
+ Update the manual user note of a DagRun.
+
+ *New in version 2.5.0*
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dag_run_endpoint
+ operationId: set_dag_run_note
+ tags: [DAGRun]
+ requestBody:
+ description: Parameters of set DagRun note.
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/SetDagRunNote'
+
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DAGRun'
+ '400':
+ $ref: '#/components/responses/BadRequest'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
/eventLogs:
get:
summary: List log entries
@@ -1116,6 +1259,37 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
+ patch:
+ summary: Updates the state of a task instance
+ description: >
+ Updates the state of a single task instance.
+
+ *New in version 2.5.0*
+ x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint
+ operationId: patch_task_instance
+ tags: [TaskInstance]
+ requestBody:
+ description: Parameters of action
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/UpdateTaskInstance'
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstanceReference'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
+
/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}:
parameters:
- $ref: '#/components/parameters/DAGID'
@@ -1146,6 +1320,36 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
+ patch:
+ summary: Updates the state of a mapped task instance
+ description: >
+ Updates the state of a single mapped task instance.
+
+ *New in version 2.5.0*
+ x-openapi-router-controller: airflow.api_connexion.endpoints.task_instance_endpoint
+ operationId: patch_mapped_task_instance
+ tags: [TaskInstance]
+ requestBody:
+ description: Parameters of action
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/UpdateTaskInstance'
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstanceReference'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
+
/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped:
parameters:
- $ref: '#/components/parameters/DAGID'
@@ -1182,7 +1386,7 @@ paths:
content:
application/json:
schema:
- $ref: '#/components/schemas/TaskInstance'
+ $ref: '#/components/schemas/TaskInstanceCollection'
'401':
$ref: '#/components/responses/Unauthenticated'
'403':
@@ -1384,6 +1588,23 @@ paths:
x-openapi-router-controller: airflow.api_connexion.endpoints.xcom_endpoint
operationId: get_xcom_entry
tags: [XCom]
+ parameters:
+ - in: query
+ name: deserialize
+ schema:
+ type: boolean
+ default: false
+ required: false
+ description: |
+ Whether to deserialize an XCom value when using a custom XCom backend.
+
+ The XCom API endpoint calls `orm_deserialize_value` by default, since an XCom may contain a value
+ that is potentially expensive to deserialize in the web server. Setting this to true overrides
+ that default and calls `deserialize_value` instead.
+
+ This parameter is not meaningful when using the default XCom backend.
+
+ *New in version 2.4.0*
responses:
'200':
description: Success.
@@ -1433,6 +1654,7 @@ paths:
- $ref: '#/components/parameters/TaskID'
- $ref: '#/components/parameters/TaskTryNumber'
- $ref: '#/components/parameters/FullContent'
+ - $ref: '#/components/parameters/FilterMapIndex'
- $ref: '#/components/parameters/ContinuationToken'
get:
@@ -1465,6 +1687,7 @@ paths:
'404':
$ref: '#/components/responses/NotFound'
+
/dags/{dag_id}/details:
parameters:
- $ref: '#/components/parameters/DAGID'
@@ -1573,6 +1796,123 @@ paths:
'406':
$ref: '#/components/responses/NotAcceptable'
+ /dagWarnings:
+ get:
+ summary: List dag warnings
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dag_warning_endpoint
+ operationId: get_dag_warnings
+ tags: [DagWarning]
+ parameters:
+ - name: dag_id
+ in: query
+ schema:
+ type: string
+ required: false
+ description: If set, only return DAG warnings with this dag_id.
+ - name: warning_type
+ in: query
+ schema:
+ type: string
+ required: false
+ description: If set, only return DAG warnings with this type.
+ - $ref: '#/components/parameters/PageLimit'
+ - $ref: '#/components/parameters/PageOffset'
+ - $ref: '#/components/parameters/OrderBy'
+
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DagWarningCollection'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+
+ /datasets:
+ get:
+ summary: List datasets
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+ operationId: get_datasets
+ tags: [Dataset]
+ parameters:
+ - $ref: '#/components/parameters/PageLimit'
+ - $ref: '#/components/parameters/PageOffset'
+ - $ref: '#/components/parameters/OrderBy'
+ - name: uri_pattern
+ in: query
+ schema:
+ type: string
+ required: false
+ description: |
+ If set, only return datasets with uris matching this pattern.
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DatasetCollection'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+
+ /datasets/{uri}:
+ parameters:
+ - $ref: '#/components/parameters/DatasetURI'
+ get:
+ summary: Get a dataset
+ description: Get a dataset by uri.
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+ operationId: get_dataset
+ tags: [Dataset]
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/Dataset'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
+ /datasets/events:
+ parameters:
+ - $ref: '#/components/parameters/PageLimit'
+ - $ref: '#/components/parameters/PageOffset'
+ - $ref: '#/components/parameters/OrderBy'
+ - $ref: '#/components/parameters/FilterDatasetID'
+ - $ref: '#/components/parameters/FilterSourceDAGID'
+ - $ref: '#/components/parameters/FilterSourceTaskID'
+ - $ref: '#/components/parameters/FilterSourceRunID'
+ - $ref: '#/components/parameters/FilterSourceMapIndex'
+ get:
+ summary: Get dataset events
+ description: Get dataset events
+ x-openapi-router-controller: airflow.api_connexion.endpoints.dataset_endpoint
+ operationId: get_dataset_events
+ tags: [Dataset]
+ responses:
+ '200':
+ description: Success.
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DatasetEventCollection'
+ '401':
+ $ref: '#/components/responses/Unauthenticated'
+ '403':
+ $ref: '#/components/responses/PermissionDenied'
+ '404':
+ $ref: '#/components/responses/NotFound'
+
/config:
get:
summary: Get current configuration
@@ -1988,15 +2328,13 @@ components:
description: |
The user's first name.
- *Changed in version 2.2.0*: A minimum character length requirement ('minLength') is added.
- minLength: 1
+ *Changed in version 2.4.0*: The requirement for this to be non-empty was removed.
last_name:
type: string
description: |
The user's last name.
- *Changed in version 2.2.0*: A minimum character length requirement ('minLength') is added.
- minLength: 1
+ *Changed in version 2.4.0*: The requirement for this to be non-empty was removed.
username:
type: string
description: |
@@ -2099,6 +2437,10 @@ components:
conn_type:
type: string
description: The connection type.
+ description:
+ type: string
+ description: The description of the connection.
+ nullable: true
host:
type: string
nullable: true
@@ -2451,6 +2793,7 @@ components:
- backfill
- manual
- scheduled
+ - dataset_triggered
state:
$ref: '#/components/schemas/DagState'
readOnly: true
@@ -2465,6 +2808,13 @@ components:
The value of this field can be set only when creating the object. If you try to modify the
field of an existing object, the request fails with an BAD_REQUEST error.
+ note:
+ type: string
+ description: |
+ Contains manually entered notes by the user about the DagRun.
+
+ *New in version 2.5.0*
+ nullable: true
UpdateDagRunState:
type: object
@@ -2496,20 +2846,62 @@ components:
$ref: '#/components/schemas/DAGRun'
- $ref: '#/components/schemas/CollectionInfo'
- EventLog:
+ DagWarning:
type: object
- description: Log of user operations via CLI or Web UI.
properties:
- event_log_id:
- description: The event log ID
- type: integer
+ dag_id:
+ type: string
readOnly: true
- when:
- description: The time when these events happened.
- format: date-time
+ description: The dag_id.
+ warning_type:
type: string
readOnly: true
- dag_id:
+ description: The warning type for the dag warning.
+ message:
+ type: string
+ readOnly: true
+ description: The message for the dag warning.
+ timestamp:
+ type: string
+ format: datetime
+ readOnly: true
+ description: The time when this warning was logged.
+
+ DagWarningCollection:
+ type: object
+ description: |
+ Collection of DAG warnings.
+
+ allOf:
+ - type: object
+ properties:
+ import_errors:
+ type: array
+ items:
+ $ref: '#/components/schemas/DagWarning'
+ - $ref: '#/components/schemas/CollectionInfo'
+
+ SetDagRunNote:
+ type: object
+ properties:
+ note:
+ description: Custom notes left by users for this Dag Run.
+ type: string
+
+ EventLog:
+ type: object
+ description: Log of user operations via CLI or Web UI.
+ properties:
+ event_log_id:
+ description: The event log ID
+ type: integer
+ readOnly: true
+ when:
+ description: The time when these events happened.
+ format: date-time
+ type: string
+ readOnly: true
+ dag_id:
description: The DAG ID
type: string
readOnly: true
@@ -2621,7 +3013,6 @@ components:
readOnly: true
nullable: true
-
Pool:
description: The pool
type: object
@@ -2726,6 +3117,59 @@ components:
nullable: true
notification_sent:
type: boolean
+ nullable: true
+
+ Trigger:
+ type: object
+ properties:
+ id:
+ type: integer
+ classpath:
+ type: string
+ kwargs:
+ type: string
+ created_date:
+ type: string
+ format: datetime
+ triggerer_id:
+ type: integer
+ nullable: true
+
+ Job:
+ type: object
+ properties:
+ id:
+ type: integer
+ dag_id:
+ type: string
+ nullable: true
+ state:
+ type: string
+ nullable: true
+ job_type:
+ type: string
+ nullable: true
+ start_date:
+ type: string
+ format: datetime
+ nullable: true
+ end_date:
+ type: string
+ format: datetime
+ nullable: true
+ latest_heartbeat:
+ type: string
+ format: datetime
+ nullable: true
+ executor_class:
+ type: string
+ nullable: true
+ hostname:
+ type: string
+ nullable: true
+ unixname:
+ type: string
+ nullable: true
TaskInstance:
type: object
@@ -2759,6 +3203,8 @@ components:
nullable: true
try_number:
type: integer
+ map_index:
+ type: integer
max_tries:
type: integer
hostname:
@@ -2771,8 +3217,10 @@ components:
type: integer
queue:
type: string
+ nullable: true
priority_weight:
type: integer
+ nullable: true
operator:
type: string
nullable: true
@@ -2795,6 +3243,19 @@ components:
*New in version 2.3.0*
type: object
+ trigger:
+ $ref: '#/components/schemas/Trigger'
+ nullable: true
+ triggerer_job:
+ $ref: '#/components/schemas/Job'
+ nullable: true
+ note:
+ type: string
+ description: |
+ Contains manually entered notes by the user about the TaskInstance.
+
+ *New in version 2.5.0*
+ nullable: true
TaskInstanceCollection:
type: object
@@ -2850,6 +3311,13 @@ components:
properties:
key:
type: string
+ description:
+ type: string
+ description: |
+ The description of the variable.
+
+ *New in version 2.4.0*
+ nullable: true
VariableCollection:
type: object
@@ -3085,6 +3553,7 @@ components:
queue:
type: string
readOnly: true
+ nullable: true
pool:
type: string
readOnly: true
@@ -3139,9 +3608,6 @@ components:
*New in version 2.1.0*
properties:
- number:
- type: string
- description: The plugin number
name:
type: string
description: The name of the plugin
@@ -3302,6 +3768,211 @@ components:
$ref: '#/components/schemas/Resource'
description: The permission resource
+ Dataset:
+ description: |
+ A dataset item.
+
+ *New in version 2.4.0*
+ type: object
+ properties:
+ id:
+ type: integer
+ description: The dataset id
+ uri:
+ type: string
+ description: The dataset uri
+ nullable: false
+ extra:
+ type: object
+ description: The dataset extra
+ nullable: true
+ created_at:
+ type: string
+ description: The dataset creation time
+ nullable: false
+ updated_at:
+ type: string
+ description: The dataset update time
+ nullable: false
+ consuming_dags:
+ type: array
+ items:
+ $ref: '#/components/schemas/DagScheduleDatasetReference'
+ producing_tasks:
+ type: array
+ items:
+ $ref: '#/components/schemas/TaskOutletDatasetReference'
+
+
+ TaskOutletDatasetReference:
+ description: |
+ A dataset reference to an upstream task.
+
+ *New in version 2.4.0*
+ type: object
+ properties:
+ dag_id:
+ type: string
+ description: The DAG ID that updates the dataset.
+ nullable: true
+ task_id:
+ type: string
+ description: The task ID that updates the dataset.
+ nullable: true
+ created_at:
+ type: string
+ description: The dataset creation time
+ nullable: false
+ updated_at:
+ type: string
+ description: The dataset update time
+ nullable: false
+
+ DagScheduleDatasetReference:
+ description: |
+ A dataset reference to a downstream DAG.
+
+ *New in version 2.4.0*
+ type: object
+ properties:
+ dag_id:
+ type: string
+ description: The DAG ID that depends on the dataset.
+ nullable: true
+ created_at:
+ type: string
+ description: The dataset reference creation time
+ nullable: false
+ updated_at:
+ type: string
+ description: The dataset reference update time
+ nullable: false
+
+ DatasetCollection:
+ description: |
+ A collection of datasets.
+
+ *New in version 2.4.0*
+ type: object
+ allOf:
+ - type: object
+ properties:
+ datasets:
+ type: array
+ items:
+ $ref: '#/components/schemas/Dataset'
+ - $ref: '#/components/schemas/CollectionInfo'
+
+ DatasetEvent:
+ description: |
+ A dataset event.
+
+ *New in version 2.4.0*
+ type: object
+ properties:
+ dataset_id:
+ type: integer
+ description: The dataset id
+ dataset_uri:
+ type: string
+ description: The URI of the dataset
+ nullable: false
+ extra:
+ type: object
+ description: The dataset event extra
+ nullable: true
+ source_dag_id:
+ type: string
+ description: The DAG ID that updated the dataset.
+ nullable: true
+ source_task_id:
+ type: string
+ description: The task ID that updated the dataset.
+ nullable: true
+ source_run_id:
+ type: string
+ description: The DAG run ID that updated the dataset.
+ nullable: true
+ source_map_index:
+ type: integer
+ description: The task map index that updated the dataset.
+ nullable: true
+ created_dagruns:
+ type: array
+ items:
+ $ref: '#/components/schemas/BasicDAGRun'
+ timestamp:
+ type: string
+ description: The dataset event creation time
+ nullable: false
+
+ BasicDAGRun:
+ type: object
+ properties:
+ run_id:
+ type: string
+ description: |
+ Run ID.
+ dag_id:
+ type: string
+ readOnly: true
+ logical_date:
+ type: string
+ description: |
+ The logical date (previously called execution date). This is the time or interval covered by
+ this DAG run, according to the DAG definition.
+
+ The value of this field can be set only when creating the object. If you try to modify the
+ field of an existing object, the request fails with a BAD_REQUEST error.
+
+ This together with DAG_ID are a unique key.
+
+ *New in version 2.2.0*
+ format: date-time
+ start_date:
+ type: string
+ format: date-time
+ description: |
+ The start time. The time when the DAG run was actually created.
+
+ *Changed in version 2.1.3*: Field becomes nullable.
+ readOnly: true
+ nullable: true
+ end_date:
+ type: string
+ format: date-time
+ readOnly: true
+ nullable: true
+ data_interval_start:
+ type: string
+ format: date-time
+ readOnly: true
+ nullable: true
+ data_interval_end:
+ type: string
+ format: date-time
+ readOnly: true
+ nullable: true
+ state:
+ $ref: '#/components/schemas/DagState'
+ readOnly: true
+
+ DatasetEventCollection:
+ description: |
+ A collection of dataset events.
+
+ *New in version 2.4.0*
+ type: object
+ allOf:
+ - type: object
+ properties:
+ dataset_events:
+ type: array
+ items:
+ $ref: '#/components/schemas/DatasetEvent'
+ - $ref: '#/components/schemas/CollectionInfo'
+
+
# Configuration
ConfigOption:
type: object
@@ -3358,7 +4029,7 @@ components:
type: boolean
default: true
- ClearTaskInstance:
+ ClearTaskInstances:
type: object
properties:
dry_run:
@@ -3410,6 +4081,31 @@ components:
description: Set state of DAG runs to RUNNING.
type: boolean
+ dag_run_id:
+ type: string
+ description: The DagRun ID for this task instance
+ nullable: true
+
+ include_upstream:
+ description: If set to true, upstream tasks are also affected.
+ type: boolean
+ default: false
+
+ include_downstream:
+ description: If set to true, downstream tasks are also affected.
+ type: boolean
+ default: false
+
+ include_future:
+ description: If set to true, tasks from future DAG runs are also affected.
+ type: boolean
+ default: false
+
+ include_past:
+ description: If set to true, tasks from past DAG runs are also affected.
+ type: boolean
+ default: false
+
UpdateTaskInstancesState:
type: object
properties:
@@ -3459,6 +4155,31 @@ components:
- success
- failed
+ UpdateTaskInstance:
+ type: object
+ properties:
+ dry_run:
+ description: |
+ If set, don't actually run this operation. The response will contain the task instance
+ planned to be affected, but it won't be modified in any way.
+ type: boolean
+ default: false
+
+ new_state:
+ description: Expected new state.
+ type: string
+ enum:
+ - success
+ - failed
+ SetTaskInstanceNote:
+ type: object
+ required:
+ - note
+ properties:
+ note:
+ description: The custom note to set for this Task Instance.
+ type: string
+
ListDagRunsForm:
type: object
properties:
@@ -3634,6 +4355,7 @@ components:
description: |
Schedule interval. Defines how often DAG runs, this object gets added to your latest task instance's
execution_date to figure out the next schedule.
+ nullable: true
readOnly: true
anyOf:
- $ref: '#/components/schemas/TimeDelta'
@@ -3779,7 +4501,10 @@ components:
*Changed in version 2.0.2*: 'removed' is added as a possible value.
- *Changed in version 2.2.0*: 'deferred' and 'sensing' is added as a possible value.
+ *Changed in version 2.2.0*: 'deferred' is added as a possible value.
+
+ *Changed in version 2.4.0*: 'sensing' state has been removed.
+ *Changed in version 2.4.2*: 'restarting' is added as a possible value.
type: string
enum:
- success
@@ -3793,8 +4518,8 @@ components:
- none
- scheduled
- deferred
- - sensing
- removed
+ - restarting
DagState:
description: |
@@ -3945,6 +4670,15 @@ components:
required: true
description: The import error ID.
+ DatasetURI:
+ in: path
+ name: uri
+ schema:
+ type: string
+ format: path
+ required: true
+ description: The encoded Dataset URI
+
PoolName:
in: path
name: pool_name
@@ -3958,6 +4692,7 @@ components:
name: variable_key
schema:
type: string
+ format: path
required: true
description: The variable Key.
@@ -3989,7 +4724,8 @@ components:
type: string
required: true
description: The XCom key.
- # Filter
+
+ # Filters
FilterExecutionDateGTE:
in: query
name: execution_date_gte
@@ -4118,6 +4854,48 @@ components:
*New in version 2.2.0*
+ FilterDatasetID:
+ in: query
+ name: dataset_id
+ schema:
+ type: integer
+ description: The ID of the dataset whose events should be returned.
+
+ FilterSourceDAGID:
+ in: query
+ name: source_dag_id
+ schema:
+ type: string
+ description: The DAG ID that updated the dataset.
+
+ FilterSourceTaskID:
+ in: query
+ name: source_task_id
+ schema:
+ type: string
+ description: The task ID that updated the dataset.
+
+ FilterSourceRunID:
+ in: query
+ name: source_run_id
+ schema:
+ type: string
+ description: The DAG run ID that updated the dataset.
+
+ FilterSourceMapIndex:
+ in: query
+ name: source_map_index
+ schema:
+ type: integer
+ description: The map index that updated the dataset.
+
+ FilterMapIndex:
+ in: query
+ name: map_index
+ schema:
+ type: integer
+ description: Filter on map index for mapped task.
+
OrderBy:
in: query
name: order_by
@@ -4143,7 +4921,6 @@ components:
*New in version 2.1.1*
# Other parameters
-
FileToken:
in: path
name: file_token
@@ -4274,6 +5051,8 @@ tags:
- name: Role
- name: Permission
- name: User
+ - name: DagWarning
+ - name: Dataset
externalDocs:
url: https://airflow.apache.org/docs/apache-airflow/stable/
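
A client-side illustration of the new dataset endpoints may help when reviewing the spec changes above. The sketch below only builds request URLs: it percent-encodes the URI for `/datasets/{uri}` (the spec calls it "the encoded Dataset URI") and assembles the query filters accepted by `/datasets/events`. The base URL, the helper names, and the choice to encode every reserved character are assumptions made for illustration, not something this patch defines.

```python
from urllib.parse import quote, urlencode

# Hypothetical base URL; adjust to the deployment under test.
BASE_URL = "http://localhost:8080/api/v1"

# Query filters declared for /datasets/events in the spec above.
EVENT_FILTERS = {
    "dataset_id",
    "source_dag_id",
    "source_task_id",
    "source_run_id",
    "source_map_index",
}


def dataset_url(uri: str) -> str:
    """Build the URL for GET /datasets/{uri}, percent-encoding the whole URI."""
    # safe="" also encodes "/" and ":" (an assumption about the expected encoding).
    return f"{BASE_URL}/datasets/{quote(uri, safe='')}"


def dataset_events_url(**filters) -> str:
    """Build the URL for GET /datasets/events with the given query filters."""
    unknown = set(filters) - EVENT_FILTERS
    if unknown:
        raise ValueError(f"Unsupported filters: {sorted(unknown)}")
    # Sort for a stable query string and skip filters that were left unset.
    query = urlencode({k: v for k, v in sorted(filters.items()) if v is not None})
    return f"{BASE_URL}/datasets/events" + (f"?{query}" if query else "")


print(dataset_url("s3://bucket/key"))
print(dataset_events_url(source_dag_id="etl", dataset_id=3))
```

Rejecting unknown filter names on the client keeps typos from being silently ignored by the server.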
diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py
index 81d1bd9280abb..8064912d921fa 100644
--- a/airflow/api_connexion/parameters.py
+++ b/airflow/api_connexion/parameters.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from datetime import datetime
from functools import wraps
-from typing import Any, Callable, Container, Dict, Optional, TypeVar, cast
+from typing import Any, Callable, Container, TypeVar, cast
from pendulum.parsing import ParserError
from sqlalchemy import text
@@ -28,21 +30,23 @@
def validate_istimezone(value: datetime) -> None:
- """Validates that a datetime is not naive"""
+ """Validates that a datetime is not naive."""
if not value.tzinfo:
raise BadRequest("Invalid datetime format", detail="Naive datetime is disallowed")
def format_datetime(value: str) -> datetime:
"""
+ Format datetime objects.
+
Datetime format parser for args since connexion doesn't parse datetimes
https://github.com/zalando/connexion/issues/476
This should only be used within connection views because it raises 400
"""
value = value.strip()
- if value[-1] != 'Z':
- value = value.replace(" ", '+')
+ if value[-1] != "Z":
+ value = value.replace(" ", "+")
try:
return timezone.parse(value)
except (ParserError, TypeError) as err:
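
The space-to-plus replacement in `format_datetime` is easiest to follow with a concrete value. A likely reason for it (an inference, not stated in this patch) is that an unescaped `+` in a query string is decoded as a space, which corrupts a UTC offset such as `+00:00`. The sketch below reproduces only the normalization step and parses the result with the standard library's `datetime.fromisoformat` in place of Airflow's `timezone.parse`; `normalize_datetime_param` is a hypothetical name.

```python
from datetime import datetime
from urllib.parse import parse_qs


def normalize_datetime_param(value: str) -> str:
    """Mirror the normalization shown in format_datetime above."""
    value = value.strip()
    # A value ending in "Z" carries no "+" offset, so it is left alone.
    if value[-1] != "Z":
        value = value.replace(" ", "+")
    return value


# The client sent "+00:00" without escaping the plus sign.
raw_query = "execution_date_gte=2022-11-01T00:00:00+00:00"

# Query-string decoding turns "+" into a space.
decoded = parse_qs(raw_query)["execution_date_gte"][0]

# Normalization restores the offset so the value parses as timezone-aware.
restored = normalize_datetime_param(decoded)
parsed = datetime.fromisoformat(restored)

print(repr(decoded))
print(repr(restored))
print(parsed.tzinfo is not None)
```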
@@ -51,6 +55,8 @@ def format_datetime(value: str) -> datetime:
def check_limit(value: int) -> int:
"""
+ Check the limit does not exceed configured value.
+
This checks the limit passed to view and raises BadRequest if
limit exceed user configured value
"""
@@ -69,7 +75,7 @@ def check_limit(value: int) -> int:
T = TypeVar("T", bound=Callable)
-def format_parameters(params_formatters: Dict[str, Callable[[Any], Any]]) -> Callable[[T], T]:
+def format_parameters(params_formatters: dict[str, Callable[[Any], Any]]) -> Callable[[T], T]:
"""
Decorator factory that create decorator that convert parameters using given formatters.
@@ -94,11 +100,11 @@ def wrapped_function(*args, **kwargs):
def apply_sorting(
query: Query,
order_by: str,
- to_replace: Optional[Dict[str, str]] = None,
- allowed_attrs: Optional[Container[str]] = None,
+ to_replace: dict[str, str] | None = None,
+ allowed_attrs: Container[str] | None = None,
) -> Query:
- """Apply sorting to query"""
- lstriped_orderby = order_by.lstrip('-')
+ """Apply sorting to query."""
+ lstriped_orderby = order_by.lstrip("-")
if allowed_attrs and lstriped_orderby not in allowed_attrs:
raise BadRequest(
detail=f"Ordering with '{lstriped_orderby}' is disallowed or "
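
The `order_by` handling in `apply_sorting` can be sketched without SQLAlchemy. The version below sorts a list of dictionaries instead of a `Query`, raises `ValueError` where the real function raises `BadRequest`, and assumes that a leading `-` requests descending order (the hunk shows only the stripping and the allow-list check, so the direction handling is an assumption). `sort_rows` is a hypothetical name.

```python
from __future__ import annotations

from typing import Any, Container


def sort_rows(
    rows: list[dict[str, Any]],
    order_by: str,
    allowed_attrs: Container[str] | None = None,
) -> list[dict[str, Any]]:
    """Sort rows by order_by; a leading '-' is assumed to mean descending."""
    attr = order_by.lstrip("-")
    # Same guard as in the hunk above: reject attributes outside the allow-list.
    if allowed_attrs and attr not in allowed_attrs:
        raise ValueError(f"Ordering with '{attr}' is disallowed")
    return sorted(rows, key=lambda row: row[attr], reverse=order_by.startswith("-"))


rows = [{"dag_id": "b"}, {"dag_id": "a"}]
print(sort_rows(rows, "dag_id", {"dag_id"}))
print(sort_rows(rows, "-dag_id", {"dag_id"}))
```

Checking the stripped name against the allow-list before sorting means a rejected attribute never reaches the sort key.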
diff --git a/airflow/api_connexion/schemas/common_schema.py b/airflow/api_connexion/schemas/common_schema.py
index 502d5b60bdddd..cf510137621b6 100644
--- a/airflow/api_connexion/schemas/common_schema.py
+++ b/airflow/api_connexion/schemas/common_schema.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import datetime
import inspect
@@ -31,13 +32,13 @@
class CronExpression(typing.NamedTuple):
- """Cron expression schema"""
+ """Cron expression schema."""
value: str
class TimeDeltaSchema(Schema):
- """Time delta schema"""
+ """Time delta schema."""
objectType = fields.Constant("TimeDelta", data_key="__type")
days = fields.Integer()
@@ -46,14 +47,14 @@ class TimeDeltaSchema(Schema):
@marshmallow.post_load
def make_time_delta(self, data, **kwargs):
- """Create time delta based on data"""
+ """Create time delta based on data."""
if "objectType" in data:
del data["objectType"]
return datetime.timedelta(**data)
class RelativeDeltaSchema(Schema):
- """Relative delta schema"""
+ """Relative delta schema."""
objectType = fields.Constant("RelativeDelta", data_key="__type")
years = fields.Integer()
@@ -74,7 +75,7 @@ class RelativeDeltaSchema(Schema):
@marshmallow.post_load
def make_relative_delta(self, data, **kwargs):
- """Create relative delta based on data"""
+ """Create relative delta based on data."""
if "objectType" in data:
del data["objectType"]
@@ -82,14 +83,14 @@ def make_relative_delta(self, data, **kwargs):
class CronExpressionSchema(Schema):
- """Cron expression schema"""
+ """Cron expression schema."""
objectType = fields.Constant("CronExpression", data_key="__type")
value = fields.String(required=True)
@marshmallow.post_load
def make_cron_expression(self, data, **kwargs):
- """Create cron expression based on data"""
+ """Create cron expression based on data."""
return CronExpression(data["value"])
@@ -118,7 +119,7 @@ def _dump(self, obj, update_fields=True, **kwargs):
return super()._dump(obj, update_fields=update_fields, **kwargs)
def get_obj_type(self, obj):
- """Select schema based on object type"""
+ """Select schema based on object type."""
if isinstance(obj, datetime.timedelta):
return "TimeDelta"
elif isinstance(obj, relativedelta.relativedelta):
@@ -130,7 +131,7 @@ def get_obj_type(self, obj):
class ColorField(fields.String):
- """Schema for color property"""
+ """Schema for color property."""
def __init__(self, **metadata):
super().__init__(**metadata)
@@ -138,7 +139,7 @@ def __init__(self, **metadata):
class WeightRuleField(fields.String):
- """Schema for WeightRule"""
+ """Schema for WeightRule."""
def __init__(self, **metadata):
super().__init__(**metadata)
@@ -146,7 +147,7 @@ def __init__(self, **metadata):
class TimezoneField(fields.String):
- """Schema for timezone"""
+ """Schema for timezone."""
class ClassReferenceSchema(Schema):
diff --git a/airflow/api_connexion/schemas/config_schema.py b/airflow/api_connexion/schemas/config_schema.py
index 2eb459ce14263..28095c177a1df 100644
--- a/airflow/api_connexion/schemas/config_schema.py
+++ b/airflow/api_connexion/schemas/config_schema.py
@@ -14,50 +14,51 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
class ConfigOptionSchema(Schema):
- """Config Option Schema"""
+ """Config Option Schema."""
key = fields.String(required=True)
value = fields.String(required=True)
class ConfigOption(NamedTuple):
- """Config option"""
+ """Config option."""
key: str
value: str
class ConfigSectionSchema(Schema):
- """Config Section Schema"""
+ """Config Section Schema."""
name = fields.String(required=True)
options = fields.List(fields.Nested(ConfigOptionSchema))
class ConfigSection(NamedTuple):
- """List of config options within a section"""
+ """List of config options within a section."""
name: str
- options: List[ConfigOption]
+ options: list[ConfigOption]
class ConfigSchema(Schema):
- """Config Schema"""
+ """Config Schema."""
sections = fields.List(fields.Nested(ConfigSectionSchema))
class Config(NamedTuple):
- """List of config sections with their options"""
+ """List of config sections with their options."""
- sections: List[ConfigSection]
+ sections: list[ConfigSection]
config_schema = ConfigSchema()
diff --git a/airflow/api_connexion/schemas/connection_schema.py b/airflow/api_connexion/schemas/connection_schema.py
index f06da92bacdab..4288ce079c554 100644
--- a/airflow/api_connexion/schemas/connection_schema.py
+++ b/airflow/api_connexion/schemas/connection_schema.py
@@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import json
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -25,14 +27,14 @@
class ConnectionCollectionItemSchema(SQLAlchemySchema):
- """Schema for a connection item"""
+ """Schema for a connection item."""
class Meta:
- """Meta"""
+ """Meta."""
model = Connection
- connection_id = auto_field('conn_id', required=True)
+ connection_id = auto_field("conn_id", required=True)
conn_type = auto_field(required=True)
description = auto_field()
host = auto_field()
@@ -42,10 +44,10 @@ class Meta:
class ConnectionSchema(ConnectionCollectionItemSchema):
- """Connection schema"""
+ """Connection schema."""
password = auto_field(load_only=True)
- extra = fields.Method('serialize_extra', deserialize='deserialize_extra', allow_none=True)
+ extra = fields.Method("serialize_extra", deserialize="deserialize_extra", allow_none=True)
@staticmethod
def serialize_extra(obj: Connection):
@@ -66,21 +68,21 @@ def deserialize_extra(value): # an explicit deserialize method is required for
class ConnectionCollection(NamedTuple):
- """List of Connections with meta"""
+ """List of Connections with meta."""
- connections: List[Connection]
+ connections: list[Connection]
total_entries: int
class ConnectionCollectionSchema(Schema):
- """Connection Collection Schema"""
+ """Connection Collection Schema."""
connections = fields.List(fields.Nested(ConnectionCollectionItemSchema))
total_entries = fields.Int()
class ConnectionTestSchema(Schema):
- """connection Test Schema"""
+ """Connection Test Schema."""
status = fields.Boolean(required=True)
message = fields.String(required=True)
diff --git a/airflow/api_connexion/schemas/dag_run_schema.py b/airflow/api_connexion/schemas/dag_run_schema.py
index 5cd79b20228fb..7ca857951b9fe 100644
--- a/airflow/api_connexion/schemas/dag_run_schema.py
+++ b/airflow/api_connexion/schemas/dag_run_schema.py
@@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import json
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import fields, post_dump, pre_load, validate
from marshmallow.schema import Schema
@@ -34,7 +36,7 @@
class ConfObject(fields.Field):
- """The conf field"""
+ """The conf field."""
def _serialize(self, value, attr, obj, **kwargs):
if not value:
@@ -51,15 +53,15 @@ def _deserialize(self, value, attr, data, **kwargs):
class DAGRunSchema(SQLAlchemySchema):
- """Schema for DAGRun"""
+ """Schema for DAGRun."""
class Meta:
- """Meta"""
+ """Meta."""
model = DagRun
dateformat = "iso"
- run_id = auto_field(data_key='dag_run_id')
+ run_id = auto_field(data_key="dag_run_id")
dag_id = auto_field(dump_only=True)
execution_date = auto_field(data_key="logical_date", validate=validate_istimezone)
start_date = auto_field(dump_only=True)
@@ -71,6 +73,7 @@ class Meta:
data_interval_end = auto_field(dump_only=True)
last_scheduling_decision = auto_field(dump_only=True)
run_type = auto_field(dump_only=True)
+ note = auto_field(dump_only=True)
@pre_load
def autogenerate(self, data, **kwargs):
@@ -110,7 +113,7 @@ def autofill(self, data, **kwargs):
class SetDagRunStateFormSchema(Schema):
- """Schema for handling the request of setting state of DAG run"""
+ """Schema for handling the request of setting state of DAG run."""
state = DagStateField(
validate=validate.OneOf(
@@ -120,32 +123,32 @@ class SetDagRunStateFormSchema(Schema):
class ClearDagRunStateFormSchema(Schema):
- """Schema for handling the request of clearing a DAG run"""
+ """Schema for handling the request of clearing a DAG run."""
dry_run = fields.Boolean(load_default=True)
class DAGRunCollection(NamedTuple):
- """List of DAGRuns with metadata"""
+ """List of DAGRuns with metadata."""
- dag_runs: List[DagRun]
+ dag_runs: list[DagRun]
total_entries: int
class DAGRunCollectionSchema(Schema):
- """DAGRun Collection schema"""
+ """DAGRun Collection schema."""
dag_runs = fields.List(fields.Nested(DAGRunSchema))
total_entries = fields.Int()
class DagRunsBatchFormSchema(Schema):
- """Schema to validate and deserialize the Form(request payload) submitted to DagRun Batch endpoint"""
+ """Schema to validate and deserialize the Form(request payload) submitted to DagRun Batch endpoint."""
class Meta:
- """Meta"""
+ """Meta."""
- datetimeformat = 'iso'
+ datetimeformat = "iso"
strict = True
order_by = fields.String()
@@ -161,8 +164,15 @@ class Meta:
end_date_lte = fields.DateTime(load_default=None, validate=validate_istimezone)
+class SetDagRunNoteFormSchema(Schema):
+ """Schema for handling the request of setting a note for a DAG run."""
+
+ note = fields.String(allow_none=True, validate=validate.Length(max=1000))
+
+
dagrun_schema = DAGRunSchema()
dagrun_collection_schema = DAGRunCollectionSchema()
set_dagrun_state_form_schema = SetDagRunStateFormSchema()
clear_dagrun_form_schema = ClearDagRunStateFormSchema()
dagruns_batch_form_schema = DagRunsBatchFormSchema()
+set_dagrun_note_form_schema = SetDagRunNoteFormSchema()
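
The constraints on the `note` field in `SetDagRunNoteFormSchema` (`allow_none=True`, `validate.Length(max=1000)`) can be restated as a plain function. This is a standard-library sketch of the same rule, not the marshmallow implementation; `validate_note` and `MAX_NOTE_LENGTH` are hypothetical names, and the error types are chosen for illustration.

```python
from __future__ import annotations

# Upper bound taken from validate.Length(max=1000) in the schema above.
MAX_NOTE_LENGTH = 1000


def validate_note(note: str | None) -> str | None:
    """Accept None or a string of at most MAX_NOTE_LENGTH characters."""
    if note is None:
        # allow_none=True: an explicit null is accepted.
        return None
    if not isinstance(note, str):
        raise TypeError("note must be a string or None")
    if len(note) > MAX_NOTE_LENGTH:
        raise ValueError(f"note is longer than {MAX_NOTE_LENGTH} characters")
    return note


print(validate_note(None))
print(len(validate_note("x" * 1000)))
```

The bound is inclusive, so a note of exactly 1000 characters passes.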
diff --git a/airflow/api_connexion/schemas/dag_schema.py b/airflow/api_connexion/schemas/dag_schema.py
index 2f369113290d9..182bbf180334f 100644
--- a/airflow/api_connexion/schemas/dag_schema.py
+++ b/airflow/api_connexion/schemas/dag_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from itsdangerous import URLSafeSerializer
from marshmallow import Schema, fields
@@ -28,10 +29,10 @@
class DagTagSchema(SQLAlchemySchema):
- """Dag Tag schema"""
+ """Dag Tag schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = DagTag
@@ -39,10 +40,10 @@ class Meta:
class DAGSchema(SQLAlchemySchema):
- """DAG schema"""
+ """DAG schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = DagModel
@@ -75,20 +76,20 @@ class Meta:
@staticmethod
def get_owners(obj: DagModel):
- """Convert owners attribute to DAG representation"""
- if not getattr(obj, 'owners', None):
+ """Convert owners attribute to DAG representation."""
+ if not getattr(obj, "owners", None):
return []
return obj.owners.split(",")
@staticmethod
def get_token(obj: DagModel):
- """Return file token"""
- serializer = URLSafeSerializer(conf.get('webserver', 'secret_key'))
+ """Return file token."""
+ serializer = URLSafeSerializer(conf.get_mandatory_value("webserver", "secret_key"))
return serializer.dumps(obj.fileloc)
class DAGDetailSchema(DAGSchema):
- """DAG details"""
+ """DAG details."""
owners = fields.Method("get_owners", dump_only=True)
timezone = TimezoneField()
@@ -100,7 +101,7 @@ class DAGDetailSchema(DAGSchema):
dag_run_timeout = fields.Nested(TimeDeltaSchema, attribute="dagrun_timeout")
doc_md = fields.String()
default_view = fields.String()
- params = fields.Method('get_params', dump_only=True)
+ params = fields.Method("get_params", dump_only=True)
tags = fields.Method("get_tags", dump_only=True) # type: ignore
is_paused = fields.Method("get_is_paused", dump_only=True)
is_active = fields.Method("get_is_active", dump_only=True)
@@ -108,7 +109,7 @@ class DAGDetailSchema(DAGSchema):
end_date = fields.DateTime(dump_only=True)
template_search_path = fields.String(dump_only=True)
render_template_as_native_obj = fields.Boolean(dump_only=True)
- last_loaded = fields.DateTime(dump_only=True, data_key='last_parsed')
+ last_loaded = fields.DateTime(dump_only=True, data_key="last_parsed")
@staticmethod
def get_concurrency(obj: DAG):
@@ -116,7 +117,7 @@ def get_concurrency(obj: DAG):
@staticmethod
def get_tags(obj: DAG):
- """Dumps tags as objects"""
+ """Dumps tags as objects."""
tags = obj.tags
if tags:
return [DagTagSchema().dump(dict(name=tag)) for tag in tags]
@@ -124,37 +125,37 @@ def get_tags(obj: DAG):
@staticmethod
def get_owners(obj: DAG):
- """Convert owners attribute to DAG representation"""
- if not getattr(obj, 'owner', None):
+ """Convert owners attribute to DAG representation."""
+ if not getattr(obj, "owner", None):
return []
return obj.owner.split(",")
@staticmethod
def get_is_paused(obj: DAG):
- """Checks entry in DAG table to see if this DAG is paused"""
+ """Checks entry in DAG table to see if this DAG is paused."""
return obj.get_is_paused()
@staticmethod
def get_is_active(obj: DAG):
- """Checks entry in DAG table to see if this DAG is active"""
+ """Checks entry in DAG table to see if this DAG is active."""
return obj.get_is_active()
@staticmethod
def get_params(obj: DAG):
- """Get the Params defined in a DAG"""
+ """Get the Params defined in a DAG."""
params = obj.params
return {k: v.dump() for k, v in params.items()}
class DAGCollection(NamedTuple):
- """List of DAGs with metadata"""
+ """List of DAGs with metadata."""
- dags: List[DagModel]
+ dags: list[DagModel]
total_entries: int
class DAGCollectionSchema(Schema):
- """DAG Collection schema"""
+ """DAG Collection schema."""
dags = fields.List(fields.Nested(DAGSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/dag_source_schema.py b/airflow/api_connexion/schemas/dag_source_schema.py
index d142454bc1f6d..adb89ce76a569 100644
--- a/airflow/api_connexion/schemas/dag_source_schema.py
+++ b/airflow/api_connexion/schemas/dag_source_schema.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from marshmallow import Schema, fields
class DagSourceSchema(Schema):
- """Dag Source schema"""
+ """Dag Source schema."""
content = fields.String(dump_only=True)
diff --git a/airflow/api_connexion/schemas/dag_warning_schema.py b/airflow/api_connexion/schemas/dag_warning_schema.py
new file mode 100644
index 0000000000000..35c9830d273c7
--- /dev/null
+++ b/airflow/api_connexion/schemas/dag_warning_schema.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.models.dagwarning import DagWarning
+
+
+class DagWarningSchema(SQLAlchemySchema):
+ """DAG warning schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = DagWarning
+
+ dag_id = auto_field(data_key="dag_id", dump_only=True)
+ warning_type = auto_field()
+ message = auto_field()
+ timestamp = auto_field(format="iso")
+
+
+class DagWarningCollection(NamedTuple):
+ """List of dag warnings with metadata."""
+
+ dag_warnings: list[DagWarning]
+ total_entries: int
+
+
+class DagWarningCollectionSchema(Schema):
+ """DAG warning collection schema."""
+
+ dag_warnings = fields.List(fields.Nested(DagWarningSchema))
+ total_entries = fields.Int()
+
+
+dag_warning_schema = DagWarningSchema()
+dag_warning_collection_schema = DagWarningCollectionSchema()
diff --git a/airflow/api_connexion/schemas/dataset_schema.py b/airflow/api_connexion/schemas/dataset_schema.py
new file mode 100644
index 0000000000000..bfdd0d24231c8
--- /dev/null
+++ b/airflow/api_connexion/schemas/dataset_schema.py
@@ -0,0 +1,150 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import NamedTuple
+
+from marshmallow import Schema, fields
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.api_connexion.schemas.common_schema import JsonObjectField
+from airflow.models.dagrun import DagRun
+from airflow.models.dataset import (
+ DagScheduleDatasetReference,
+ DatasetEvent,
+ DatasetModel,
+ TaskOutletDatasetReference,
+)
+
+
+class TaskOutletDatasetReferenceSchema(SQLAlchemySchema):
+ """TaskOutletDatasetReference DB schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = TaskOutletDatasetReference
+
+ dag_id = auto_field()
+ task_id = auto_field()
+ created_at = auto_field()
+ updated_at = auto_field()
+
+
+class DagScheduleDatasetReferenceSchema(SQLAlchemySchema):
+ """DagScheduleDatasetReference DB schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = DagScheduleDatasetReference
+
+ dag_id = auto_field()
+ created_at = auto_field()
+ updated_at = auto_field()
+
+
+class DatasetSchema(SQLAlchemySchema):
+ """Dataset DB schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = DatasetModel
+
+ id = auto_field()
+ uri = auto_field()
+ extra = JsonObjectField()
+ created_at = auto_field()
+ updated_at = auto_field()
+ producing_tasks = fields.List(fields.Nested(TaskOutletDatasetReferenceSchema))
+ consuming_dags = fields.List(fields.Nested(DagScheduleDatasetReferenceSchema))
+
+
+class DatasetCollection(NamedTuple):
+ """List of Datasets with meta."""
+
+ datasets: list[DatasetModel]
+ total_entries: int
+
+
+class DatasetCollectionSchema(Schema):
+ """Dataset Collection Schema."""
+
+ datasets = fields.List(fields.Nested(DatasetSchema))
+ total_entries = fields.Int()
+
+
+dataset_schema = DatasetSchema()
+dataset_collection_schema = DatasetCollectionSchema()
+
+
+class BasicDAGRunSchema(SQLAlchemySchema):
+ """Basic Schema for DAGRun."""
+
+ class Meta:
+ """Meta."""
+
+ model = DagRun
+ dateformat = "iso"
+
+ run_id = auto_field(data_key="dag_run_id")
+ dag_id = auto_field(dump_only=True)
+ execution_date = auto_field(data_key="logical_date", dump_only=True)
+ start_date = auto_field(dump_only=True)
+ end_date = auto_field(dump_only=True)
+ state = auto_field(dump_only=True)
+ data_interval_start = auto_field(dump_only=True)
+ data_interval_end = auto_field(dump_only=True)
+
+
+class DatasetEventSchema(SQLAlchemySchema):
+ """Dataset Event DB schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = DatasetEvent
+
+ id = auto_field()
+ dataset_id = auto_field()
+ dataset_uri = fields.String(attribute="dataset.uri", dump_only=True)
+ extra = JsonObjectField()
+ source_task_id = auto_field()
+ source_dag_id = auto_field()
+ source_run_id = auto_field()
+ source_map_index = auto_field()
+ created_dagruns = fields.List(fields.Nested(BasicDAGRunSchema))
+ timestamp = auto_field()
+
+
+class DatasetEventCollection(NamedTuple):
+ """List of Dataset events with meta."""
+
+ dataset_events: list[DatasetEvent]
+ total_entries: int
+
+
+class DatasetEventCollectionSchema(Schema):
+ """Dataset Event Collection Schema."""
+
+ dataset_events = fields.List(fields.Nested(DatasetEventSchema))
+ total_entries = fields.Int()
+
+
+dataset_event_schema = DatasetEventSchema()
+dataset_event_collection_schema = DatasetEventCollectionSchema()
diff --git a/airflow/api_connexion/schemas/enum_schemas.py b/airflow/api_connexion/schemas/enum_schemas.py
index 71faf9fa20436..981a3669b1b58 100644
--- a/airflow/api_connexion/schemas/enum_schemas.py
+++ b/airflow/api_connexion/schemas/enum_schemas.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from marshmallow import fields, validate
@@ -21,7 +22,7 @@
class DagStateField(fields.String):
- """Schema for DagState Enum"""
+ """Schema for DagState Enum."""
def __init__(self, **metadata):
super().__init__(**metadata)
@@ -29,7 +30,7 @@ def __init__(self, **metadata):
class TaskInstanceStateField(fields.String):
- """Schema for TaskInstanceState Enum"""
+ """Schema for TaskInstanceState Enum."""
def __init__(self, **metadata):
super().__init__(**metadata)
diff --git a/airflow/api_connexion/schemas/error_schema.py b/airflow/api_connexion/schemas/error_schema.py
index c9462b5f967a8..dcd4d37ff7781 100644
--- a/airflow/api_connexion/schemas/error_schema.py
+++ b/airflow/api_connexion/schemas/error_schema.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, NamedTuple
+from __future__ import annotations
+
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -23,10 +25,10 @@
class ImportErrorSchema(SQLAlchemySchema):
- """Import error schema"""
+ """Import error schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = ImportError
@@ -39,14 +41,14 @@ class Meta:
class ImportErrorCollection(NamedTuple):
- """List of import errors with metadata"""
+ """List of import errors with metadata."""
- import_errors: List[ImportError]
+ import_errors: list[ImportError]
total_entries: int
class ImportErrorCollectionSchema(Schema):
- """Import error collection schema"""
+ """Import error collection schema."""
import_errors = fields.List(fields.Nested(ImportErrorSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/event_log_schema.py b/airflow/api_connexion/schemas/event_log_schema.py
index d97c223bffa23..5bf4ccf00d0b3 100644
--- a/airflow/api_connexion/schemas/event_log_schema.py
+++ b/airflow/api_connexion/schemas/event_log_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -24,15 +25,15 @@
class EventLogSchema(SQLAlchemySchema):
- """Event log schema"""
+ """Event log schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = Log
- id = auto_field(data_key='event_log_id', dump_only=True)
- dttm = auto_field(data_key='when', dump_only=True)
+ id = auto_field(data_key="event_log_id", dump_only=True)
+ dttm = auto_field(data_key="when", dump_only=True)
dag_id = auto_field(dump_only=True)
task_id = auto_field(dump_only=True)
event = auto_field(dump_only=True)
@@ -42,14 +43,14 @@ class Meta:
class EventLogCollection(NamedTuple):
- """List of import errors with metadata"""
+ """List of event logs with metadata."""
- event_logs: List[Log]
+ event_logs: list[Log]
total_entries: int
class EventLogCollectionSchema(Schema):
- """EventLog Collection Schema"""
+ """EventLog Collection Schema."""
event_logs = fields.List(fields.Nested(EventLogSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/health_schema.py b/airflow/api_connexion/schemas/health_schema.py
index 7089babb6230b..67155406c1a79 100644
--- a/airflow/api_connexion/schemas/health_schema.py
+++ b/airflow/api_connexion/schemas/health_schema.py
@@ -14,28 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from marshmallow import Schema, fields
class BaseInfoSchema(Schema):
- """Base status field for metadatabase and scheduler"""
+ """Base status field for metadatabase and scheduler."""
status = fields.String(dump_only=True)
class MetaDatabaseInfoSchema(BaseInfoSchema):
- """Schema for Metadatabase info"""
+ """Schema for Metadatabase info."""
class SchedulerInfoSchema(BaseInfoSchema):
- """Schema for Metadatabase info"""
+ """Schema for Scheduler info."""
latest_scheduler_heartbeat = fields.String(dump_only=True)
class HealthInfoSchema(Schema):
- """Schema for the Health endpoint"""
+ """Schema for the Health endpoint."""
metadatabase = fields.Nested(MetaDatabaseInfoSchema)
scheduler = fields.Nested(SchedulerInfoSchema)
diff --git a/airflow/api_connexion/schemas/job_schema.py b/airflow/api_connexion/schemas/job_schema.py
new file mode 100644
index 0000000000000..4d98d39c92030
--- /dev/null
+++ b/airflow/api_connexion/schemas/job_schema.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.jobs.base_job import BaseJob
+
+
+class JobSchema(SQLAlchemySchema):
+ """Job schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = BaseJob
+
+ id = auto_field(dump_only=True)
+ dag_id = auto_field(dump_only=True)
+ state = auto_field(dump_only=True)
+ job_type = auto_field(dump_only=True)
+ start_date = auto_field(dump_only=True)
+ end_date = auto_field(dump_only=True)
+ latest_heartbeat = auto_field(dump_only=True)
+ executor_class = auto_field(dump_only=True)
+ hostname = auto_field(dump_only=True)
+ unixname = auto_field(dump_only=True)
diff --git a/airflow/api_connexion/schemas/log_schema.py b/airflow/api_connexion/schemas/log_schema.py
index eff97e1723d63..82e291fafc42c 100644
--- a/airflow/api_connexion/schemas/log_schema.py
+++ b/airflow/api_connexion/schemas/log_schema.py
@@ -14,23 +14,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import NamedTuple, Optional
+from __future__ import annotations
+
+from typing import NamedTuple
from marshmallow import Schema, fields
class LogsSchema(Schema):
- """Schema for logs"""
+ """Schema for logs."""
content = fields.Str()
continuation_token = fields.Str()
class LogResponseObject(NamedTuple):
- """Log Response Object"""
+ """Log Response Object."""
content: str
- continuation_token: Optional[str]
+ continuation_token: str | None
logs_schema = LogsSchema()
diff --git a/airflow/api_connexion/schemas/plugin_schema.py b/airflow/api_connexion/schemas/plugin_schema.py
index 89704cd191be6..780fef17bf76a 100644
--- a/airflow/api_connexion/schemas/plugin_schema.py
+++ b/airflow/api_connexion/schemas/plugin_schema.py
@@ -14,37 +14,37 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
class PluginSchema(Schema):
- """Plugin schema"""
+ """Plugin schema."""
- number = fields.Int()
name = fields.String()
hooks = fields.List(fields.String())
executors = fields.List(fields.String())
- macros = fields.List(fields.String())
- flask_blueprints = fields.List(fields.String())
- appbuilder_views = fields.List(fields.String())
+ macros = fields.List(fields.Dict())
+ flask_blueprints = fields.List(fields.Dict())
+ appbuilder_views = fields.List(fields.Dict())
appbuilder_menu_items = fields.List(fields.Dict())
- global_operator_extra_links = fields.List(fields.String())
- operator_extra_links = fields.List(fields.String())
+ global_operator_extra_links = fields.List(fields.Dict())
+ operator_extra_links = fields.List(fields.Dict())
source = fields.String()
class PluginCollection(NamedTuple):
- """Plugin List"""
+ """Plugin List."""
- plugins: List
+ plugins: list
total_entries: int
class PluginCollectionSchema(Schema):
- """Plugin Collection List"""
+ """Plugin Collection List."""
plugins = fields.List(fields.Nested(PluginSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/pool_schema.py b/airflow/api_connexion/schemas/pool_schema.py
index 1f91b0d5bdd1c..4e25287d1d357 100644
--- a/airflow/api_connexion/schemas/pool_schema.py
+++ b/airflow/api_connexion/schemas/pool_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -24,10 +25,10 @@
class PoolSchema(SQLAlchemySchema):
- """Pool schema"""
+ """Pool schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = Pool
@@ -36,6 +37,7 @@ class Meta:
occupied_slots = fields.Method("get_occupied_slots", dump_only=True)
running_slots = fields.Method("get_running_slots", dump_only=True)
queued_slots = fields.Method("get_queued_slots", dump_only=True)
+ scheduled_slots = fields.Method("get_scheduled_slots", dump_only=True)
open_slots = fields.Method("get_open_slots", dump_only=True)
description = auto_field()
@@ -54,6 +56,11 @@ def get_queued_slots(obj: Pool) -> int:
"""Returns the queued slots of the pool."""
return obj.queued_slots()
+ @staticmethod
+ def get_scheduled_slots(obj: Pool) -> int:
+ """Returns the scheduled slots of the pool."""
+ return obj.scheduled_slots()
+
@staticmethod
def get_open_slots(obj: Pool) -> float:
"""Returns the open slots of the pool."""
@@ -61,14 +68,14 @@ def get_open_slots(obj: Pool) -> float:
class PoolCollection(NamedTuple):
- """List of Pools with metadata"""
+ """List of Pools with metadata."""
- pools: List[Pool]
+ pools: list[Pool]
total_entries: int
class PoolCollectionSchema(Schema):
- """Pool Collection schema"""
+ """Pool Collection schema."""
pools = fields.List(fields.Nested(PoolSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/provider_schema.py b/airflow/api_connexion/schemas/provider_schema.py
index 4c9867380bd99..ad62f4ae26f10 100644
--- a/airflow/api_connexion/schemas/provider_schema.py
+++ b/airflow/api_connexion/schemas/provider_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
@@ -41,7 +42,7 @@ class Provider(TypedDict):
class ProviderCollection(NamedTuple):
"""List of Providers."""
- providers: List[Provider]
+ providers: list[Provider]
total_entries: int
diff --git a/airflow/api_connexion/schemas/role_and_permission_schema.py b/airflow/api_connexion/schemas/role_and_permission_schema.py
index 4031750199f83..324336c288668 100644
--- a/airflow/api_connexion/schemas/role_and_permission_schema.py
+++ b/airflow/api_connexion/schemas/role_and_permission_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -24,10 +25,10 @@
class ActionSchema(SQLAlchemySchema):
- """Action Action Schema"""
+ """Action Schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = Action
@@ -35,10 +36,10 @@ class Meta:
class ResourceSchema(SQLAlchemySchema):
- """View menu Schema"""
+ """Resource (view menu) Schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = Resource
@@ -46,24 +47,24 @@ class Meta:
class ActionCollection(NamedTuple):
- """Action Action Collection"""
+ """Action Collection."""
- actions: List[Action]
+ actions: list[Action]
total_entries: int
class ActionCollectionSchema(Schema):
- """Permissions list schema"""
+ """Permissions list schema."""
actions = fields.List(fields.Nested(ActionSchema))
total_entries = fields.Int()
class ActionResourceSchema(SQLAlchemySchema):
- """Action View Schema"""
+ """Action View Schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = Permission
@@ -72,26 +73,26 @@ class Meta:
class RoleSchema(SQLAlchemySchema):
- """Role item schema"""
+ """Role item schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = Role
name = auto_field()
- permissions = fields.List(fields.Nested(ActionResourceSchema), data_key='actions')
+ permissions = fields.List(fields.Nested(ActionResourceSchema), data_key="actions")
class RoleCollection(NamedTuple):
- """List of roles"""
+ """List of roles."""
- roles: List[Role]
+ roles: list[Role]
total_entries: int
class RoleCollectionSchema(Schema):
- """List of roles"""
+ """List of roles."""
roles = fields.List(fields.Nested(RoleSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/sla_miss_schema.py b/airflow/api_connexion/schemas/sla_miss_schema.py
index 9413e37cbde21..97a462e186d59 100644
--- a/airflow/api_connexion/schemas/sla_miss_schema.py
+++ b/airflow/api_connexion/schemas/sla_miss_schema.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -21,10 +22,10 @@
class SlaMissSchema(SQLAlchemySchema):
- """Sla Miss Schema"""
+ """Sla Miss Schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = SlaMiss
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py b/airflow/api_connexion/schemas/task_instance_schema.py
index 37005256f6cdc..970ef9a0fd4d6 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple, Optional, Tuple
+from typing import NamedTuple
from marshmallow import Schema, ValidationError, fields, validate, validates_schema
from marshmallow.utils import get_value
@@ -24,17 +25,19 @@
from airflow.api_connexion.parameters import validate_istimezone
from airflow.api_connexion.schemas.common_schema import JsonObjectField
from airflow.api_connexion.schemas.enum_schemas import TaskInstanceStateField
+from airflow.api_connexion.schemas.job_schema import JobSchema
from airflow.api_connexion.schemas.sla_miss_schema import SlaMissSchema
+from airflow.api_connexion.schemas.trigger_schema import TriggerSchema
from airflow.models import SlaMiss, TaskInstance
from airflow.utils.helpers import exactly_one
from airflow.utils.state import State
class TaskInstanceSchema(SQLAlchemySchema):
- """Task instance schema"""
+ """Task instance schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = TaskInstance
@@ -59,8 +62,11 @@ class Meta:
queued_dttm = auto_field(data_key="queued_when")
pid = auto_field()
executor_config = auto_field()
+ note = auto_field()
sla_miss = fields.Nested(SlaMissSchema, dump_default=None)
- rendered_fields = JsonObjectField(default={})
+ rendered_fields = JsonObjectField(dump_default={})
+ trigger = fields.Nested(TriggerSchema)
+ triggerer_job = fields.Nested(JobSchema)
def get_attribute(self, obj, attr, default):
if attr == "sla_miss":
@@ -75,21 +81,21 @@ def get_attribute(self, obj, attr, default):
class TaskInstanceCollection(NamedTuple):
- """List of task instances with metadata"""
+ """List of task instances with metadata."""
- task_instances: List[Tuple[TaskInstance, Optional[SlaMiss]]]
+ task_instances: list[tuple[TaskInstance, SlaMiss | None]]
total_entries: int
class TaskInstanceCollectionSchema(Schema):
- """Task instance collection schema"""
+ """Task instance collection schema."""
task_instances = fields.List(fields.Nested(TaskInstanceSchema))
total_entries = fields.Int()
class TaskInstanceBatchFormSchema(Schema):
- """Schema for the request form passed to Task Instance Batch endpoint"""
+ """Schema for the request form passed to Task Instance Batch endpoint."""
page_offset = fields.Int(load_default=0, validate=validate.Range(min=0))
page_limit = fields.Int(load_default=100, validate=validate.Range(min=1))
@@ -108,7 +114,7 @@ class TaskInstanceBatchFormSchema(Schema):
class ClearTaskInstanceFormSchema(Schema):
- """Schema for handling the request of clearing task instance of a Dag"""
+ """Schema for handling the request of clearing task instance of a Dag."""
dry_run = fields.Boolean(load_default=True)
start_date = fields.DateTime(load_default=None, validate=validate_istimezone)
@@ -119,19 +125,30 @@ class ClearTaskInstanceFormSchema(Schema):
include_parentdag = fields.Boolean(load_default=False)
reset_dag_runs = fields.Boolean(load_default=False)
task_ids = fields.List(fields.String(), validate=validate.Length(min=1))
+ dag_run_id = fields.Str(load_default=None)
+ include_upstream = fields.Boolean(load_default=False)
+ include_downstream = fields.Boolean(load_default=False)
+ include_future = fields.Boolean(load_default=False)
+ include_past = fields.Boolean(load_default=False)
@validates_schema
def validate_form(self, data, **kwargs):
- """Validates clear task instance form"""
+ """Validates clear task instance form."""
if data["only_failed"] and data["only_running"]:
raise ValidationError("only_failed and only_running both are set to True")
if data["start_date"] and data["end_date"]:
if data["start_date"] > data["end_date"]:
raise ValidationError("end_date is sooner than start_date")
+ if data["start_date"] and data["end_date"] and data["dag_run_id"]:
+ raise ValidationError("Exactly one of dag_run_id or (start_date and end_date) must be provided")
+ if data["start_date"] and data["dag_run_id"]:
+ raise ValidationError("Exactly one of dag_run_id or start_date must be provided")
+ if data["end_date"] and data["dag_run_id"]:
+ raise ValidationError("Exactly one of dag_run_id or end_date must be provided")
class SetTaskInstanceStateFormSchema(Schema):
- """Schema for handling the request of setting state of task instance of a DAG"""
+ """Schema for handling the request of setting state of task instance of a DAG."""
dry_run = fields.Boolean(dump_default=True)
task_id = fields.Str(required=True)
@@ -145,13 +162,20 @@ class SetTaskInstanceStateFormSchema(Schema):
@validates_schema
def validate_form(self, data, **kwargs):
- """Validates set task instance state form"""
+ """Validates set task instance state form."""
if not exactly_one(data.get("execution_date"), data.get("dag_run_id")):
raise ValidationError("Exactly one of execution_date or dag_run_id must be provided")
+class SetSingleTaskInstanceStateFormSchema(Schema):
+ """Schema for handling the request of updating state of a single task instance."""
+
+ dry_run = fields.Boolean(dump_default=True)
+ new_state = TaskInstanceStateField(required=True, validate=validate.OneOf([State.SUCCESS, State.FAILED]))
+
+
class TaskInstanceReferenceSchema(Schema):
- """Schema for the task instance reference schema"""
+ """Schema for the task instance reference schema."""
task_id = fields.Str()
run_id = fields.Str(data_key="dag_run_id")
@@ -160,21 +184,31 @@ class TaskInstanceReferenceSchema(Schema):
class TaskInstanceReferenceCollection(NamedTuple):
- """List of objects with metadata about taskinstance and dag_run_id"""
+ """List of objects with metadata about taskinstance and dag_run_id."""
- task_instances: List[Tuple[TaskInstance, str]]
+ task_instances: list[tuple[TaskInstance, str]]
class TaskInstanceReferenceCollectionSchema(Schema):
- """Collection schema for task reference"""
+ """Collection schema for task reference."""
task_instances = fields.List(fields.Nested(TaskInstanceReferenceSchema))
+class SetTaskInstanceNoteFormSchema(Schema):
+ """Schema for setting a note for a TaskInstance."""
+
+ # Note: We can't add map_index to the url as subpaths can't start with dashes.
+ map_index = fields.Int(allow_none=False)
+ note = fields.String(allow_none=True, validate=validate.Length(max=1000))
+
+
task_instance_schema = TaskInstanceSchema()
task_instance_collection_schema = TaskInstanceCollectionSchema()
task_instance_batch_form = TaskInstanceBatchFormSchema()
clear_task_instance_form = ClearTaskInstanceFormSchema()
set_task_instance_state_form = SetTaskInstanceStateFormSchema()
+set_single_task_instance_state_form = SetSingleTaskInstanceStateFormSchema()
task_instance_reference_schema = TaskInstanceReferenceSchema()
task_instance_reference_collection_schema = TaskInstanceReferenceCollectionSchema()
+set_task_instance_note_form_schema = SetTaskInstanceNoteFormSchema()
diff --git a/airflow/api_connexion/schemas/task_schema.py b/airflow/api_connexion/schemas/task_schema.py
index aa9a4703088b0..5715ca2ea0fa3 100644
--- a/airflow/api_connexion/schemas/task_schema.py
+++ b/airflow/api_connexion/schemas/task_schema.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import List, NamedTuple
+from typing import NamedTuple
from marshmallow import Schema, fields
@@ -26,13 +27,15 @@
WeightRuleField,
)
from airflow.api_connexion.schemas.dag_schema import DAGSchema
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
class TaskSchema(Schema):
- """Task schema"""
+ """Task schema."""
class_ref = fields.Method("_get_class_reference", dump_only=True)
+ operator_name = fields.Method("_get_operator_name", dump_only=True)
task_id = fields.String(dump_only=True)
owner = fields.String(dump_only=True)
start_date = fields.DateTime(dump_only=True)
@@ -57,29 +60,38 @@ class TaskSchema(Schema):
template_fields = fields.List(fields.String(), dump_only=True)
sub_dag = fields.Nested(DAGSchema, dump_only=True)
downstream_task_ids = fields.List(fields.String(), dump_only=True)
- params = fields.Method('get_params', dump_only=True)
- is_mapped = fields.Boolean(dump_only=True)
+ params = fields.Method("_get_params", dump_only=True)
+ is_mapped = fields.Method("_get_is_mapped", dump_only=True)
- def _get_class_reference(self, obj):
+ @staticmethod
+ def _get_class_reference(obj):
result = ClassReferenceSchema().dump(obj)
return result.data if hasattr(result, "data") else result
@staticmethod
- def get_params(obj):
- """Get the Params defined in a Task"""
+ def _get_operator_name(obj):
+ return obj.operator_name
+
+ @staticmethod
+ def _get_params(obj):
+ """Get the Params defined in a Task."""
params = obj.params
return {k: v.dump() for k, v in params.items()}
+ @staticmethod
+ def _get_is_mapped(obj):
+ return isinstance(obj, MappedOperator)
+
class TaskCollection(NamedTuple):
- """List of Tasks with metadata"""
+ """List of Tasks with metadata."""
- tasks: List[Operator]
+ tasks: list[Operator]
total_entries: int
class TaskCollectionSchema(Schema):
- """Schema for TaskCollection"""
+ """Schema for TaskCollection."""
tasks = fields.List(fields.Nested(TaskSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/trigger_schema.py b/airflow/api_connexion/schemas/trigger_schema.py
new file mode 100644
index 0000000000000..15d180a5732ff
--- /dev/null
+++ b/airflow/api_connexion/schemas/trigger_schema.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
+
+from airflow.models import Trigger
+
+
+class TriggerSchema(SQLAlchemySchema):
+ """Trigger Schema."""
+
+ class Meta:
+ """Meta."""
+
+ model = Trigger
+
+ id = auto_field(dump_only=True)
+ classpath = auto_field(dump_only=True)
+ kwargs = auto_field(dump_only=True)
+ created_date = auto_field(dump_only=True)
+ triggerer_id = auto_field(dump_only=True)
diff --git a/airflow/api_connexion/schemas/user_schema.py b/airflow/api_connexion/schemas/user_schema.py
index 3d36aa91c8031..843ad32f0245c 100644
--- a/airflow/api_connexion/schemas/user_schema.py
+++ b/airflow/api_connexion/schemas/user_schema.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, NamedTuple
+from __future__ import annotations
+
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -25,10 +27,10 @@
class UserCollectionItemSchema(SQLAlchemySchema):
- """user collection item schema"""
+ """User collection item schema."""
class Meta:
- """Meta"""
+ """Meta."""
model = User
dateformat = "iso"
@@ -41,26 +43,26 @@ class Meta:
last_login = auto_field(dump_only=True)
login_count = auto_field(dump_only=True)
fail_login_count = auto_field(dump_only=True)
- roles = fields.List(fields.Nested(RoleSchema, only=('name',)))
+ roles = fields.List(fields.Nested(RoleSchema, only=("name",)))
created_on = auto_field(validate=validate_istimezone, dump_only=True)
changed_on = auto_field(validate=validate_istimezone, dump_only=True)
class UserSchema(UserCollectionItemSchema):
- """User schema"""
+ """User schema."""
password = auto_field(load_only=True)
class UserCollection(NamedTuple):
- """User collection"""
+ """User collection."""
- users: List[User]
+ users: list[User]
total_entries: int
class UserCollectionSchema(Schema):
- """User collection schema"""
+ """User collection schema."""
users = fields.List(fields.Nested(UserCollectionItemSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/variable_schema.py b/airflow/api_connexion/schemas/variable_schema.py
index 6b5d16e4227d6..ffe54f0742907 100644
--- a/airflow/api_connexion/schemas/variable_schema.py
+++ b/airflow/api_connexion/schemas/variable_schema.py
@@ -14,19 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from marshmallow import Schema, fields
class VariableSchema(Schema):
- """Variable Schema"""
+ """Variable Schema."""
key = fields.String(required=True)
value = fields.String(attribute="val", required=True)
+ description = fields.String(attribute="description", required=False)
class VariableCollectionSchema(Schema):
- """Variable Collection Schema"""
+ """Variable Collection Schema."""
variables = fields.List(fields.Nested(VariableSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/schemas/version_schema.py b/airflow/api_connexion/schemas/version_schema.py
index 24bd9337c1c36..519f91c55e816 100644
--- a/airflow/api_connexion/schemas/version_schema.py
+++ b/airflow/api_connexion/schemas/version_schema.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from marshmallow import Schema, fields
class VersionInfoSchema(Schema):
- """Version information schema"""
+ """Version information schema."""
version = fields.String(dump_only=True)
git_version = fields.String(dump_only=True)
diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py
index b3f3f0dd021a2..09d2505bf7d4d 100644
--- a/airflow/api_connexion/schemas/xcom_schema.py
+++ b/airflow/api_connexion/schemas/xcom_schema.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, NamedTuple
+from __future__ import annotations
+
+from typing import NamedTuple
from marshmallow import Schema, fields
from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
@@ -23,10 +25,10 @@
class XComCollectionItemSchema(SQLAlchemySchema):
- """Schema for a xcom item"""
+ """Schema for an XCom item."""
class Meta:
- """Meta"""
+ """Meta."""
model = XCom
@@ -38,20 +40,20 @@ class Meta:
class XComSchema(XComCollectionItemSchema):
- """XCom schema"""
+ """XCom schema."""
value = auto_field()
class XComCollection(NamedTuple):
- """List of XComs with meta"""
+ """List of XComs with meta."""
- xcom_entries: List[XCom]
+ xcom_entries: list[XCom]
total_entries: int
class XComCollectionSchema(Schema):
- """XCom Collection Schema"""
+ """XCom Collection Schema."""
xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema))
total_entries = fields.Int()
diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py
index 3562c98eb4b35..0e3e9b1f3748b 100644
--- a/airflow/api_connexion/security.py
+++ b/airflow/api_connexion/security.py
@@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from functools import wraps
-from typing import Callable, Optional, Sequence, Tuple, TypeVar, cast
+from typing import Callable, Sequence, TypeVar, cast
-from flask import Response, current_app
+from flask import Response
from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated
+from airflow.utils.airflow_flask_app import get_airflow_app
T = TypeVar("T", bound=Callable)
def check_authentication() -> None:
"""Checks that the request has valid authorization information."""
- for auth in current_app.api_auth:
+ for auth in get_airflow_app().api_auth:
response = auth.requires_authentication(Response)()
if response.status_code == 200:
return
@@ -36,16 +38,16 @@ def check_authentication() -> None:
raise Unauthenticated(headers=response.headers)
-def requires_access(permissions: Optional[Sequence[Tuple[str, str]]] = None) -> Callable[[T], T]:
+def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]:
"""Factory for decorator that checks current user's permissions against required permissions."""
- appbuilder = current_app.appbuilder
+ appbuilder = get_airflow_app().appbuilder
appbuilder.sm.sync_resource_permissions(permissions)
def requires_access_decorator(func: T):
@wraps(func)
def decorated(*args, **kwargs):
check_authentication()
- if appbuilder.sm.check_authorization(permissions, kwargs.get('dag_id')):
+ if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")):
return func(*args, **kwargs)
raise PermissionDenied()
diff --git a/airflow/api_connexion/types.py b/airflow/api_connexion/types.py
index f640d14bd7e79..3a6f89d9bb52a 100644
--- a/airflow/api_connexion/types.py
+++ b/airflow/api_connexion/types.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
diff --git a/airflow/callbacks/base_callback_sink.py b/airflow/callbacks/base_callback_sink.py
index e7cbf23e7b959..c243f0fbd640f 100644
--- a/airflow/callbacks/base_callback_sink.py
+++ b/airflow/callbacks/base_callback_sink.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
from airflow.callbacks.callback_requests import CallbackRequest
diff --git a/airflow/callbacks/callback_requests.py b/airflow/callbacks/callback_requests.py
index 8112589cd0262..8ec0187978db6 100644
--- a/airflow/callbacks/callback_requests.py
+++ b/airflow/callbacks/callback_requests.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from airflow.models.taskinstance import SimpleTaskInstance
@@ -28,10 +29,17 @@ class CallbackRequest:
:param full_filepath: File Path to use to run the callback
:param msg: Additional Message that can be used for logging
+ :param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
"""
- def __init__(self, full_filepath: str, msg: Optional[str] = None):
+ def __init__(
+ self,
+ full_filepath: str,
+ processor_subdir: str | None = None,
+ msg: str | None = None,
+ ):
self.full_filepath = full_filepath
+ self.processor_subdir = processor_subdir
self.msg = msg
def __eq__(self, other):
@@ -53,6 +61,8 @@ def from_json(cls, json_str: str):
class TaskCallbackRequest(CallbackRequest):
"""
+ Task callback status information.
+
A Class with information about the success/failure TI callback to be executed. Currently, only failure
callbacks (when tasks are externally killed) and Zombies are run via DagFileProcessorProcess.
@@ -60,31 +70,33 @@ class TaskCallbackRequest(CallbackRequest):
:param simple_task_instance: Simplified Task Instance representation
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging to determine failure/zombie
+ :param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
"""
def __init__(
self,
full_filepath: str,
- simple_task_instance: "SimpleTaskInstance",
- is_failure_callback: Optional[bool] = True,
- msg: Optional[str] = None,
+ simple_task_instance: SimpleTaskInstance,
+ is_failure_callback: bool | None = True,
+ processor_subdir: str | None = None,
+ msg: str | None = None,
):
- super().__init__(full_filepath=full_filepath, msg=msg)
+ super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.simple_task_instance = simple_task_instance
self.is_failure_callback = is_failure_callback
def to_json(self) -> str:
- dict_obj = self.__dict__.copy()
- dict_obj["simple_task_instance"] = dict_obj["simple_task_instance"].__dict__
- return json.dumps(dict_obj)
+ from airflow.serialization.serialized_objects import BaseSerialization
+
+ val = BaseSerialization.serialize(self.__dict__, strict=True)
+ return json.dumps(val)
@classmethod
def from_json(cls, json_str: str):
- from airflow.models.taskinstance import SimpleTaskInstance
+ from airflow.serialization.serialized_objects import BaseSerialization
- kwargs = json.loads(json_str)
- simple_ti = SimpleTaskInstance.from_dict(obj_dict=kwargs.pop("simple_task_instance"))
- return cls(simple_task_instance=simple_ti, **kwargs)
+ val = json.loads(json_str)
+ return cls(**BaseSerialization.deserialize(val))
class DagCallbackRequest(CallbackRequest):
@@ -94,6 +106,7 @@ class DagCallbackRequest(CallbackRequest):
:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
:param run_id: Run ID for the DagRun
+ :param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
:param is_failure_callback: Flag to determine whether it is a Failure Callback or Success Callback
:param msg: Additional Message that can be used for logging
"""
@@ -103,10 +116,11 @@ def __init__(
full_filepath: str,
dag_id: str,
run_id: str,
- is_failure_callback: Optional[bool] = True,
- msg: Optional[str] = None,
+ processor_subdir: str | None,
+ is_failure_callback: bool | None = True,
+ msg: str | None = None,
):
- super().__init__(full_filepath=full_filepath, msg=msg)
+ super().__init__(full_filepath=full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
self.run_id = run_id
self.is_failure_callback = is_failure_callback
@@ -118,8 +132,15 @@ class SlaCallbackRequest(CallbackRequest):
:param full_filepath: File Path to use to run the callback
:param dag_id: DAG ID
+ :param processor_subdir: Directory used by the Dag Processor when it parsed the dag.
"""
- def __init__(self, full_filepath: str, dag_id: str, msg: Optional[str] = None):
- super().__init__(full_filepath, msg)
+ def __init__(
+ self,
+ full_filepath: str,
+ dag_id: str,
+ processor_subdir: str | None,
+ msg: str | None = None,
+ ):
+ super().__init__(full_filepath, processor_subdir=processor_subdir, msg=msg)
self.dag_id = dag_id
diff --git a/airflow/callbacks/database_callback_sink.py b/airflow/callbacks/database_callback_sink.py
index b9b81e6ed745a..24306170dfea2 100644
--- a/airflow/callbacks/database_callback_sink.py
+++ b/airflow/callbacks/database_callback_sink.py
@@ -15,13 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
from sqlalchemy.orm import Session
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
-from airflow.models import DbCallbackRequest
+from airflow.models.db_callback_request import DbCallbackRequest
from airflow.utils.session import NEW_SESSION, provide_session
diff --git a/airflow/callbacks/pipe_callback_sink.py b/airflow/callbacks/pipe_callback_sink.py
index 1e11ffd4f5650..d702a781fa57c 100644
--- a/airflow/callbacks/pipe_callback_sink.py
+++ b/airflow/callbacks/pipe_callback_sink.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
from multiprocessing.connection import Connection as MultiprocessingConnection
from typing import Callable
diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py
index 60c77c7a37feb..33e513b586484 100644
--- a/airflow/cli/cli_parser.py
+++ b/airflow/cli/cli_parser.py
@@ -16,7 +16,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Command-line interface"""
+"""Command-line interface."""
+from __future__ import annotations
import argparse
import json
@@ -24,7 +25,7 @@
import textwrap
from argparse import Action, ArgumentError, RawTextHelpFormatter
from functools import lru_cache
-from typing import Callable, Dict, Iterable, List, NamedTuple, Optional, Union
+from typing import Callable, Iterable, NamedTuple, Union
import lazy_object_proxy
@@ -43,8 +44,8 @@
def lazy_load_command(import_path: str) -> Callable:
- """Create a lazy loader for command"""
- _, _, name = import_path.rpartition('.')
+ """Create a lazy loader for command."""
+ _, _, name = import_path.rpartition(".")
def command(*args, **kwargs):
func = import_string(import_path)
@@ -56,12 +57,12 @@ def command(*args, **kwargs):
class DefaultHelpParser(argparse.ArgumentParser):
- """CustomParser to display help message"""
+ """CustomParser to display help message."""
def _check_value(self, action, value):
- """Override _check_value and check conditionally added command"""
- if action.dest == 'subcommand' and value == 'celery':
- executor = conf.get('core', 'EXECUTOR')
+ """Override _check_value and check conditionally added command."""
+ if action.dest == "subcommand" and value == "celery":
+ executor = conf.get("core", "EXECUTOR")
if executor not in (CELERY_EXECUTOR, CELERY_KUBERNETES_EXECUTOR):
executor_cls, _ = ExecutorLoader.import_executor_cls(executor)
classes = ()
@@ -83,12 +84,12 @@ def _check_value(self, action, value):
pass
if not issubclass(executor_cls, classes):
message = (
- f'celery subcommand works only with CeleryExecutor, CeleryKubernetesExecutor and '
- f'executors derived from them, your current executor: {executor}, subclassed from: '
+ f"celery subcommand works only with CeleryExecutor, CeleryKubernetesExecutor and "
+ f"executors derived from them, your current executor: {executor}, subclassed from: "
f'{", ".join([base_cls.__qualname__ for base_cls in executor_cls.__bases__])}'
)
raise ArgumentError(action, message)
- if action.dest == 'subcommand' and value == 'kubernetes':
+ if action.dest == "subcommand" and value == "kubernetes":
try:
import kubernetes.client # noqa: F401
except ImportError:
@@ -104,9 +105,9 @@ def _check_value(self, action, value):
super()._check_value(action, value)
def error(self, message):
- """Override error and use print_instead of print_usage"""
+ """Override error and use print_help instead of print_usage."""
self.print_help()
- self.exit(2, f'\n{self.prog} command error: {message}, see help above.\n')
+ self.exit(2, f"\n{self.prog} command error: {message}, see help above.\n")
# Used in Arg to enable `None' as a distinct value from "not passed"
@@ -114,7 +115,7 @@ def error(self, message):
class Arg:
- """Class to keep information about command line argument"""
+ """Class to keep information about command line argument."""
def __init__(
self,
@@ -140,7 +141,7 @@ def __init__(
self.kwargs[k] = v
def add_to_parser(self, parser: argparse.ArgumentParser):
- """Add this argument to an ArgumentParser"""
+ """Add this argument to an ArgumentParser."""
parser.add_argument(*self.flags, **self.kwargs)
@@ -162,12 +163,12 @@ def _check(value):
def string_list_type(val):
- """Parses comma-separated list and returns list of string (strips whitespace)"""
- return [x.strip() for x in val.split(',')]
+ """Parses a comma-separated list and returns a list of strings (strips whitespace)."""
+ return [x.strip() for x in val.split(",")]
def string_lower_type(val):
- """Lowers arg"""
+ """Lowers arg."""
if not val:
return
return val.strip().lower()
@@ -177,8 +178,16 @@ def string_lower_type(val):
ARG_DAG_ID = Arg(("dag_id",), help="The id of the dag")
ARG_TASK_ID = Arg(("task_id",), help="The id of the task")
ARG_EXECUTION_DATE = Arg(("execution_date",), help="The execution date of the DAG", type=parsedate)
+ARG_EXECUTION_DATE_OPTIONAL = Arg(
+ ("execution_date",), nargs="?", help="The execution date of the DAG (optional)", type=parsedate
+)
ARG_EXECUTION_DATE_OR_RUN_ID = Arg(
- ('execution_date_or_run_id',), help="The execution_date of the DAG or run_id of the DAGRun"
+ ("execution_date_or_run_id",), help="The execution_date of the DAG or run_id of the DAGRun"
+)
+ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL = Arg(
+ ("execution_date_or_run_id",),
+ nargs="?",
+ help="The execution_date of the DAG or run_id of the DAGRun (optional)",
)
ARG_TASK_REGEX = Arg(
("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)"
@@ -190,7 +199,7 @@ def string_lower_type(val):
"Defaults to '[AIRFLOW_HOME]/dags' where [AIRFLOW_HOME] is the "
"value you set for 'AIRFLOW_HOME' config you set in 'airflow.cfg' "
),
- default='[AIRFLOW_HOME]/dags' if BUILD_DOCS else settings.DAGS_FOLDER,
+ default="[AIRFLOW_HOME]/dags" if BUILD_DOCS else settings.DAGS_FOLDER,
)
ARG_START_DATE = Arg(("-s", "--start-date"), help="Override start_date YYYY-MM-DD", type=parsedate)
ARG_END_DATE = Arg(("-e", "--end-date"), help="Override end_date YYYY-MM-DD", type=parsedate)
@@ -208,7 +217,7 @@ def string_lower_type(val):
help="Perform a dry run for each task. Only renders Template Fields for each task, nothing else",
action="store_true",
)
-ARG_PID = Arg(("--pid",), help="PID file location", nargs='?')
+ARG_PID = Arg(("--pid",), help="PID file location", nargs="?")
ARG_DAEMON = Arg(
("-D", "--daemon"), help="Daemonize instead of running in the foreground", action="store_true"
)
@@ -232,7 +241,7 @@ def string_lower_type(val):
default="table",
)
ARG_COLOR = Arg(
- ('--color',),
+ ("--color",),
help="Do emit colored output (default: auto)",
choices={ColorMode.ON, ColorMode.OFF, ColorMode.AUTO},
default=ColorMode.AUTO,
@@ -245,7 +254,7 @@ def string_lower_type(val):
default=None,
)
ARG_REVISION_RANGE = Arg(
- ('--revision-range',),
+ ("--revision-range",),
help=(
"Migration revision range(start:end) to use for offline sql generation. "
"Example: ``a13f7613ad25:7b2661a43ba3``"
@@ -254,13 +263,16 @@ def string_lower_type(val):
)
# list_dag_runs
-ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag")
+ARG_DAG_ID_REQ_FLAG = Arg(
+ ("-d", "--dag-id"), required=True, help="The id of the dag"
+) # TODO: convert this to a positional arg in Airflow 3
ARG_NO_BACKFILL = Arg(
("--no-backfill",), help="filter all the backfill dagruns given the dag id", action="store_true"
)
ARG_STATE = Arg(("--state",), help="Only list the dag runs corresponding to the state")
# list_jobs
+ARG_DAG_ID_OPT = Arg(("-d", "--dag-id"), help="The id of the dag")
ARG_LIMIT = Arg(("--limit",), help="Return a limited number of records")
# next_execution
@@ -339,6 +351,11 @@ def string_lower_type(val):
help=("if set, the backfill will keep going even if some of the tasks failed"),
action="store_true",
)
+ARG_DISABLE_RETRY = Arg(
+ ("--disable-retry",),
+ help=("if set, the backfill will set tasks as failed without retrying."),
+ action="store_true",
+)
ARG_RUN_BACKWARDS = Arg(
(
"-B",
@@ -351,6 +368,11 @@ def string_lower_type(val):
),
action="store_true",
)
+ARG_TREAT_DAG_AS_REGEX = Arg(
+ ("--treat-dag-as-regex",),
+ help=("if set, dag_id will be treated as regex instead of an exact string"),
+ action="store_true",
+)
# test_dag
ARG_SHOW_DAGRUN = Arg(
("--show-dagrun",),
@@ -359,7 +381,7 @@ def string_lower_type(val):
"\n"
"The diagram is in DOT language\n"
),
- action='store_true',
+ action="store_true",
)
ARG_IMGCAT_DAGRUN = Arg(
("--imgcat-dagrun",),
@@ -367,7 +389,7 @@ def string_lower_type(val):
"After completing the dag run, prints a diagram on the screen for the "
"current DAG Run using the imgcat tool.\n"
),
- action='store_true',
+ action="store_true",
)
ARG_SAVE_DAGRUN = Arg(
("--save-dagrun",),
@@ -405,11 +427,11 @@ def string_lower_type(val):
# show_dag
ARG_SAVE = Arg(("-s", "--save"), help="Saves the result to the indicated file.")
-ARG_IMGCAT = Arg(("--imgcat",), help="Displays graph using the imgcat tool.", action='store_true')
+ARG_IMGCAT = Arg(("--imgcat",), help="Displays graph using the imgcat tool.", action="store_true")
# trigger_dag
ARG_RUN_ID = Arg(("-r", "--run-id"), help="Helps to identify this run")
-ARG_CONF = Arg(('-c', '--conf'), help="JSON string that gets pickled into the DagRun's conf attribute")
+ARG_CONF = Arg(("-c", "--conf"), help="JSON string that gets pickled into the DagRun's conf attribute")
ARG_EXEC_DATE = Arg(("-e", "--exec-date"), help="The execution date of the DAG", type=parsedate)
# db
@@ -434,10 +456,15 @@ def string_lower_type(val):
help="Perform a dry run",
action="store_true",
)
+ARG_DB_SKIP_ARCHIVE = Arg(
+ ("--skip-archive",),
+ help="Don't preserve purged records in an archive table.",
+ action="store_true",
+)
# pool
-ARG_POOL_NAME = Arg(("pool",), metavar='NAME', help="Pool name")
+ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name")
ARG_POOL_SLOTS = Arg(("slots",), type=int, help="Pool slots")
ARG_POOL_DESCRIPTION = Arg(("description",), help="Pool description")
ARG_POOL_IMPORT = Arg(
@@ -446,11 +473,11 @@ def string_lower_type(val):
help="Import pools from JSON file. Example format::\n"
+ textwrap.indent(
textwrap.dedent(
- '''
+ """
{
"pool_1": {"slots": 5, "description": ""},
"pool_2": {"slots": 10, "description": "test"}
- }'''
+ }"""
),
" " * 4,
),
@@ -460,22 +487,23 @@ def string_lower_type(val):
# variables
ARG_VAR = Arg(("key",), help="Variable key")
-ARG_VAR_VALUE = Arg(("value",), metavar='VALUE', help="Variable value")
+ARG_VAR_VALUE = Arg(("value",), metavar="VALUE", help="Variable value")
ARG_DEFAULT = Arg(
("-d", "--default"), metavar="VAL", default=None, help="Default value returned if variable does not exist"
)
-ARG_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true")
+ARG_DESERIALIZE_JSON = Arg(("-j", "--json"), help="Deserialize JSON variable", action="store_true")
+ARG_SERIALIZE_JSON = Arg(("-j", "--json"), help="Serialize JSON variable", action="store_true")
ARG_VAR_IMPORT = Arg(("file",), help="Import variables from JSON file")
ARG_VAR_EXPORT = Arg(("file",), help="Export all variables to JSON file")
# kerberos
-ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs='?')
-ARG_KEYTAB = Arg(("-k", "--keytab"), help="keytab", nargs='?', default=conf.get('kerberos', 'keytab'))
+ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs="?")
+ARG_KEYTAB = Arg(("-k", "--keytab"), help="keytab", nargs="?", default=conf.get("kerberos", "keytab"))
# run
ARG_INTERACTIVE = Arg(
- ('-N', '--interactive'),
- help='Do not capture standard output and error streams (useful for interactive debugging)',
- action='store_true',
+ ("-N", "--interactive"),
+ help="Do not capture standard output and error streams (useful for interactive debugging)",
+ action="store_true",
)
# TODO(aoen): "force" is a poor choice of name here since it implies it overrides
# all dependencies (not just past success), e.g. the ignore_depends_on_past
@@ -514,7 +542,7 @@ def string_lower_type(val):
ARG_PICKLE = Arg(("-p", "--pickle"), help="Serialized pickle object of the entire dag (used internally)")
ARG_JOB_ID = Arg(("-j", "--job-id"), help=argparse.SUPPRESS)
ARG_CFG_PATH = Arg(("--cfg-path",), help="Path to config file to use instead of airflow.cfg")
-ARG_MAP_INDEX = Arg(('--map-index',), type=int, default=-1, help="Mapped task index")
+ARG_MAP_INDEX = Arg(("--map-index",), type=int, default=-1, help="Mapped task index")
# database
@@ -524,6 +552,14 @@ def string_lower_type(val):
type=int,
default=60,
)
+ARG_DB_RESERIALIZE_DAGS = Arg(
+ ("--no-reserialize-dags",),
+ # Not intended for users, so don't show in help
+ help=argparse.SUPPRESS,
+ action="store_false",
+ default=True,
+ dest="reserialize_dags",
+)
ARG_DB_VERSION__UPGRADE = Arg(
("-n", "--to-version"),
help=(
@@ -554,7 +590,7 @@ def string_lower_type(val):
ARG_DB_SQL_ONLY = Arg(
("-s", "--show-sql-only"),
help="Don't actually run migrations; just print out sql scripts for offline migration. "
- "Required if using either `--from-version` or `--from-version`.",
+ "Required if using either `--from-revision` or `--from-version`.",
action="store_true",
default=False,
)
@@ -568,41 +604,41 @@ def string_lower_type(val):
# webserver
ARG_PORT = Arg(
("-p", "--port"),
- default=conf.get('webserver', 'WEB_SERVER_PORT'),
+ default=conf.get("webserver", "WEB_SERVER_PORT"),
type=int,
help="The port on which to run the server",
)
ARG_SSL_CERT = Arg(
("--ssl-cert",),
- default=conf.get('webserver', 'WEB_SERVER_SSL_CERT'),
+ default=conf.get("webserver", "WEB_SERVER_SSL_CERT"),
help="Path to the SSL certificate for the webserver",
)
ARG_SSL_KEY = Arg(
("--ssl-key",),
- default=conf.get('webserver', 'WEB_SERVER_SSL_KEY'),
+ default=conf.get("webserver", "WEB_SERVER_SSL_KEY"),
help="Path to the key to use with the SSL certificate",
)
ARG_WORKERS = Arg(
("-w", "--workers"),
- default=conf.get('webserver', 'WORKERS'),
+ default=conf.get("webserver", "WORKERS"),
type=int,
help="Number of workers to run the webserver on",
)
ARG_WORKERCLASS = Arg(
("-k", "--workerclass"),
- default=conf.get('webserver', 'WORKER_CLASS'),
- choices=['sync', 'eventlet', 'gevent', 'tornado'],
+ default=conf.get("webserver", "WORKER_CLASS"),
+ choices=["sync", "eventlet", "gevent", "tornado"],
help="The worker class to use for Gunicorn",
)
ARG_WORKER_TIMEOUT = Arg(
("-t", "--worker-timeout"),
- default=conf.get('webserver', 'WEB_SERVER_WORKER_TIMEOUT'),
+ default=conf.get("webserver", "WEB_SERVER_WORKER_TIMEOUT"),
type=int,
help="The timeout for waiting on webserver workers",
)
ARG_HOSTNAME = Arg(
("-H", "--hostname"),
- default=conf.get('webserver', 'WEB_SERVER_HOST'),
+ default=conf.get("webserver", "WEB_SERVER_HOST"),
help="Set the hostname on which to run the web server",
)
ARG_DEBUG = Arg(
@@ -610,24 +646,24 @@ def string_lower_type(val):
)
ARG_ACCESS_LOGFILE = Arg(
("-A", "--access-logfile"),
- default=conf.get('webserver', 'ACCESS_LOGFILE'),
- help="The logfile to store the webserver access log. Use '-' to print to stderr",
+ default=conf.get("webserver", "ACCESS_LOGFILE"),
+ help="The logfile to store the webserver access log. Use '-' to print to stdout",
)
ARG_ERROR_LOGFILE = Arg(
("-E", "--error-logfile"),
- default=conf.get('webserver', 'ERROR_LOGFILE'),
+ default=conf.get("webserver", "ERROR_LOGFILE"),
help="The logfile to store the webserver error log. Use '-' to print to stderr",
)
ARG_ACCESS_LOGFORMAT = Arg(
("-L", "--access-logformat"),
- default=conf.get('webserver', 'ACCESS_LOGFORMAT'),
+ default=conf.get("webserver", "ACCESS_LOGFORMAT"),
help="The access log format for gunicorn logs",
)
# scheduler
ARG_NUM_RUNS = Arg(
("-n", "--num-runs"),
- default=conf.getint('scheduler', 'num_runs'),
+ default=conf.getint("scheduler", "num_runs"),
type=int,
help="Set the number of runs to execute before exiting",
)
@@ -646,13 +682,13 @@ def string_lower_type(val):
ARG_QUEUES = Arg(
("-q", "--queues"),
help="Comma delimited list of queues to serve",
- default=conf.get('operators', 'DEFAULT_QUEUE'),
+ default=conf.get("operators", "DEFAULT_QUEUE"),
)
ARG_CONCURRENCY = Arg(
("-c", "--concurrency"),
type=int,
help="The number of worker processes",
- default=conf.get('celery', 'worker_concurrency'),
+ default=conf.get("celery", "worker_concurrency"),
)
ARG_CELERY_HOSTNAME = Arg(
("-H", "--celery-hostname"),
@@ -661,18 +697,17 @@ def string_lower_type(val):
ARG_UMASK = Arg(
("-u", "--umask"),
help="Set the umask of celery worker in daemon mode",
- default=conf.get('celery', 'worker_umask'),
)
ARG_WITHOUT_MINGLE = Arg(
("--without-mingle",),
default=False,
- help="Don’t synchronize with other workers at start-up",
+ help="Don't synchronize with other workers at start-up",
action="store_true",
)
ARG_WITHOUT_GOSSIP = Arg(
("--without-gossip",),
default=False,
- help="Don’t subscribe to other workers events",
+ help="Don't subscribe to other workers events",
action="store_true",
)
@@ -680,22 +715,22 @@ def string_lower_type(val):
ARG_BROKER_API = Arg(("-a", "--broker-api"), help="Broker API")
ARG_FLOWER_HOSTNAME = Arg(
("-H", "--hostname"),
- default=conf.get('celery', 'FLOWER_HOST'),
+ default=conf.get("celery", "FLOWER_HOST"),
help="Set the hostname on which to run the server",
)
ARG_FLOWER_PORT = Arg(
("-p", "--port"),
- default=conf.get('celery', 'FLOWER_PORT'),
+ default=conf.get("celery", "FLOWER_PORT"),
type=int,
help="The port on which to run the server",
)
ARG_FLOWER_CONF = Arg(("-c", "--flower-conf"), help="Configuration file for flower")
ARG_FLOWER_URL_PREFIX = Arg(
- ("-u", "--url-prefix"), default=conf.get('celery', 'FLOWER_URL_PREFIX'), help="URL prefix for Flower"
+ ("-u", "--url-prefix"), default=conf.get("celery", "FLOWER_URL_PREFIX"), help="URL prefix for Flower"
)
ARG_FLOWER_BASIC_AUTH = Arg(
("-A", "--basic-auth"),
- default=conf.get('celery', 'FLOWER_BASIC_AUTH'),
+ default=conf.get("celery", "FLOWER_BASIC_AUTH"),
help=(
"Securing Flower with Basic Authentication. "
"Accepts user:password pairs separated by a comma. "
@@ -713,91 +748,91 @@ def string_lower_type(val):
)
# connections
-ARG_CONN_ID = Arg(('conn_id',), help='Connection id, required to get/add/delete a connection', type=str)
+ARG_CONN_ID = Arg(("conn_id",), help="Connection id, required to get/add/delete a connection", type=str)
ARG_CONN_ID_FILTER = Arg(
- ('--conn-id',), help='If passed, only items with the specified connection ID will be displayed', type=str
+ ("--conn-id",), help="If passed, only items with the specified connection ID will be displayed", type=str
)
ARG_CONN_URI = Arg(
- ('--conn-uri',), help='Connection URI, required to add a connection without conn_type', type=str
+ ("--conn-uri",), help="Connection URI, required to add a connection without conn_type", type=str
)
ARG_CONN_JSON = Arg(
- ('--conn-json',), help='Connection JSON, required to add a connection using JSON representation', type=str
+ ("--conn-json",), help="Connection JSON, required to add a connection using JSON representation", type=str
)
ARG_CONN_TYPE = Arg(
- ('--conn-type',), help='Connection type, required to add a connection without conn_uri', type=str
+ ("--conn-type",), help="Connection type, required to add a connection without conn_uri", type=str
)
ARG_CONN_DESCRIPTION = Arg(
- ('--conn-description',), help='Connection description, optional when adding a connection', type=str
+ ("--conn-description",), help="Connection description, optional when adding a connection", type=str
)
-ARG_CONN_HOST = Arg(('--conn-host',), help='Connection host, optional when adding a connection', type=str)
-ARG_CONN_LOGIN = Arg(('--conn-login',), help='Connection login, optional when adding a connection', type=str)
+ARG_CONN_HOST = Arg(("--conn-host",), help="Connection host, optional when adding a connection", type=str)
+ARG_CONN_LOGIN = Arg(("--conn-login",), help="Connection login, optional when adding a connection", type=str)
ARG_CONN_PASSWORD = Arg(
- ('--conn-password',), help='Connection password, optional when adding a connection', type=str
+ ("--conn-password",), help="Connection password, optional when adding a connection", type=str
)
ARG_CONN_SCHEMA = Arg(
- ('--conn-schema',), help='Connection schema, optional when adding a connection', type=str
+ ("--conn-schema",), help="Connection schema, optional when adding a connection", type=str
)
-ARG_CONN_PORT = Arg(('--conn-port',), help='Connection port, optional when adding a connection', type=str)
+ARG_CONN_PORT = Arg(("--conn-port",), help="Connection port, optional when adding a connection", type=str)
ARG_CONN_EXTRA = Arg(
- ('--conn-extra',), help='Connection `Extra` field, optional when adding a connection', type=str
+ ("--conn-extra",), help="Connection `Extra` field, optional when adding a connection", type=str
)
ARG_CONN_EXPORT = Arg(
- ('file',),
- help='Output file path for exporting the connections',
- type=argparse.FileType('w', encoding='UTF-8'),
+ ("file",),
+ help="Output file path for exporting the connections",
+ type=argparse.FileType("w", encoding="UTF-8"),
)
ARG_CONN_EXPORT_FORMAT = Arg(
- ('--format',),
- help='Deprecated -- use `--file-format` instead. File format to use for the export.',
+ ("--format",),
+ help="Deprecated -- use `--file-format` instead. File format to use for the export.",
type=str,
- choices=['json', 'yaml', 'env'],
+ choices=["json", "yaml", "env"],
)
ARG_CONN_EXPORT_FILE_FORMAT = Arg(
- ('--file-format',), help='File format for the export', type=str, choices=['json', 'yaml', 'env']
+ ("--file-format",), help="File format for the export", type=str, choices=["json", "yaml", "env"]
)
ARG_CONN_SERIALIZATION_FORMAT = Arg(
- ('--serialization-format',),
- help='When exporting as `.env` format, defines how connections should be serialized. Default is `uri`.',
+ ("--serialization-format",),
+ help="When exporting as `.env` format, defines how connections should be serialized. Default is `uri`.",
type=string_lower_type,
- choices=['json', 'uri'],
+ choices=["json", "uri"],
)
ARG_CONN_IMPORT = Arg(("file",), help="Import connections from a file")
# providers
ARG_PROVIDER_NAME = Arg(
- ('provider_name',), help='Provider name, required to get provider information', type=str
+ ("provider_name",), help="Provider name, required to get provider information", type=str
)
ARG_FULL = Arg(
- ('-f', '--full'),
- help='Full information about the provider, including documentation information.',
+ ("-f", "--full"),
+ help="Full information about the provider, including documentation information.",
required=False,
action="store_true",
)
# users
-ARG_USERNAME = Arg(('-u', '--username'), help='Username of the user', required=True, type=str)
-ARG_USERNAME_OPTIONAL = Arg(('-u', '--username'), help='Username of the user', type=str)
-ARG_FIRSTNAME = Arg(('-f', '--firstname'), help='First name of the user', required=True, type=str)
-ARG_LASTNAME = Arg(('-l', '--lastname'), help='Last name of the user', required=True, type=str)
+ARG_USERNAME = Arg(("-u", "--username"), help="Username of the user", required=True, type=str)
+ARG_USERNAME_OPTIONAL = Arg(("-u", "--username"), help="Username of the user", type=str)
+ARG_FIRSTNAME = Arg(("-f", "--firstname"), help="First name of the user", required=True, type=str)
+ARG_LASTNAME = Arg(("-l", "--lastname"), help="Last name of the user", required=True, type=str)
ARG_ROLE = Arg(
- ('-r', '--role'),
- help='Role of the user. Existing roles include Admin, User, Op, Viewer, and Public',
+ ("-r", "--role"),
+ help="Role of the user. Existing roles include Admin, User, Op, Viewer, and Public",
required=True,
type=str,
)
-ARG_EMAIL = Arg(('-e', '--email'), help='Email of the user', required=True, type=str)
-ARG_EMAIL_OPTIONAL = Arg(('-e', '--email'), help='Email of the user', type=str)
+ARG_EMAIL = Arg(("-e", "--email"), help="Email of the user", required=True, type=str)
+ARG_EMAIL_OPTIONAL = Arg(("-e", "--email"), help="Email of the user", type=str)
ARG_PASSWORD = Arg(
- ('-p', '--password'),
- help='Password of the user, required to create a user without --use-random-password',
+ ("-p", "--password"),
+ help="Password of the user, required to create a user without --use-random-password",
type=str,
)
ARG_USE_RANDOM_PASSWORD = Arg(
- ('--use-random-password',),
- help='Do not prompt for password. Use random string instead.'
- ' Required to create a user without --password ',
+ ("--use-random-password",),
+ help="Do not prompt for password. Use random string instead."
+ " Required to create a user without --password ",
default=False,
- action='store_true',
+ action="store_true",
)
ARG_USER_IMPORT = Arg(
("import",),
@@ -805,7 +840,7 @@ def string_lower_type(val):
help="Import users from JSON file. Example format::\n"
+ textwrap.indent(
textwrap.dedent(
- '''
+ """
[
{
"email": "foo@bar.org",
@@ -814,7 +849,7 @@ def string_lower_type(val):
"roles": ["Public"],
"username": "jondoe"
}
- ]'''
+ ]"""
),
" " * 4,
),
@@ -822,10 +857,14 @@ def string_lower_type(val):
ARG_USER_EXPORT = Arg(("export",), metavar="FILEPATH", help="Export all users to JSON file")
# roles
-ARG_CREATE_ROLE = Arg(('-c', '--create'), help='Create a new role', action='store_true')
-ARG_LIST_ROLES = Arg(('-l', '--list'), help='List roles', action='store_true')
-ARG_ROLES = Arg(('role',), help='The name of a role', nargs='*')
-ARG_AUTOSCALE = Arg(('-a', '--autoscale'), help="Minimum and Maximum number of worker to autoscale")
+ARG_CREATE_ROLE = Arg(("-c", "--create"), help="Create a new role", action="store_true")
+ARG_LIST_ROLES = Arg(("-l", "--list"), help="List roles", action="store_true")
+ARG_ROLES = Arg(("role",), help="The name of a role", nargs="*")
+ARG_PERMISSIONS = Arg(("-p", "--permission"), help="Show role permissions", action="store_true")
+ARG_ROLE_RESOURCE = Arg(("-r", "--resource"), help="The name of permissions", nargs="*", required=True)
+ARG_ROLE_ACTION = Arg(("-a", "--action"), help="The action of permissions", nargs="*")
+ARG_ROLE_ACTION_REQUIRED = Arg(("-a", "--action"), help="The action of permissions", nargs="*", required=True)
+ARG_AUTOSCALE = Arg(("-a", "--autoscale"), help="Minimum and Maximum number of worker to autoscale")
ARG_SKIP_SERVE_LOGS = Arg(
("-s", "--skip-serve-logs"),
default=False,
@@ -835,19 +874,19 @@ def string_lower_type(val):
ARG_ROLE_IMPORT = Arg(("file",), help="Import roles from JSON file", nargs=None)
ARG_ROLE_EXPORT = Arg(("file",), help="Export all roles to JSON file", nargs=None)
ARG_ROLE_EXPORT_FMT = Arg(
- ('-p', '--pretty'),
- help='Format output JSON file by sorting role names and indenting by 4 spaces',
- action='store_true',
+ ("-p", "--pretty"),
+ help="Format output JSON file by sorting role names and indenting by 4 spaces",
+ action="store_true",
)
# info
ARG_ANONYMIZE = Arg(
- ('--anonymize',),
- help='Minimize any personal identifiable information. Use it when sharing output with others.',
- action='store_true',
+ ("--anonymize",),
+ help="Minimize any personal identifiable information. Use it when sharing output with others.",
+ action="store_true",
)
ARG_FILE_IO = Arg(
- ('--file-io',), help='Send output to file.io service and returns link.', action='store_true'
+ ("--file-io",), help="Send output to file.io service and returns link.", action="store_true"
)
# config
@@ -863,7 +902,7 @@ def string_lower_type(val):
# kubernetes cleanup-pods
ARG_NAMESPACE = Arg(
("--namespace",),
- default=conf.get('kubernetes', 'namespace'),
+ default=conf.get("kubernetes_executor", "namespace"),
help="Kubernetes Namespace. Default value is `[kubernetes_executor] namespace` in configuration.",
)
@@ -879,10 +918,10 @@ def string_lower_type(val):
# jobs check
ARG_JOB_TYPE_FILTER = Arg(
- ('--job-type',),
- choices=('BackfillJob', 'LocalTaskJob', 'SchedulerJob', 'TriggererJob'),
- action='store',
- help='The type of job(s) that will be checked.',
+ ("--job-type",),
+ choices=("BackfillJob", "LocalTaskJob", "SchedulerJob", "TriggererJob"),
+ action="store",
+ help="The type of job(s) that will be checked.",
)
ARG_JOB_HOSTNAME_FILTER = Arg(
@@ -892,6 +931,13 @@ def string_lower_type(val):
help="The hostname of job(s) that will be checked.",
)
+ARG_JOB_HOSTNAME_CALLABLE_FILTER = Arg(
+ ("--local",),
+ action="store_true",
+ help="If passed, this command will only show jobs from the local host "
+ "(those with a hostname matching what `hostname_callable` returns).",
+)
+
ARG_JOB_LIMIT = Arg(
("--limit",),
default=1,
@@ -901,7 +947,7 @@ def string_lower_type(val):
ARG_ALLOW_MULTIPLE = Arg(
("--allow-multiple",),
- action='store_true',
+ action="store_true",
help="If passed, this command will be successful even if multiple matching alive jobs are found.",
)
@@ -936,49 +982,49 @@ def string_lower_type(val):
class ActionCommand(NamedTuple):
- """Single CLI command"""
+ """Single CLI command."""
name: str
help: str
func: Callable
args: Iterable[Arg]
- description: Optional[str] = None
- epilog: Optional[str] = None
+ description: str | None = None
+ epilog: str | None = None
class GroupCommand(NamedTuple):
- """ClI command with subcommands"""
+ """CLI command with subcommands."""
name: str
help: str
subcommands: Iterable
- description: Optional[str] = None
- epilog: Optional[str] = None
+ description: str | None = None
+ epilog: str | None = None
CLICommand = Union[ActionCommand, GroupCommand]
DAGS_COMMANDS = (
ActionCommand(
- name='list',
+ name="list",
help="List all the DAGs",
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_dags'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_dags"),
args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='list-import-errors',
+ name="list-import-errors",
help="List all the DAGs that have import errors",
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_import_errors'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_import_errors"),
args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='report',
- help='Show DagBag loading report',
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_report'),
+ name="report",
+ help="Show DagBag loading report",
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_report"),
args=(ARG_SUBDIR, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='list-runs',
+ name="list-runs",
help="List DAG runs given a DAG id",
description=(
"List DAG runs given a DAG id. If state option is given, it will only search for all the "
@@ -987,9 +1033,9 @@ class GroupCommand(NamedTuple):
"dagruns that were executed before this date. If end_date is given, it will filter out "
"all the dagruns that were executed after this date. "
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_dag_runs'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_dag_runs"),
args=(
- ARG_DAG_ID_OPT,
+ ARG_DAG_ID_REQ_FLAG,
ARG_NO_BACKFILL,
ARG_STATE,
ARG_OUTPUT,
@@ -999,53 +1045,53 @@ class GroupCommand(NamedTuple):
),
),
ActionCommand(
- name='list-jobs',
+ name="list-jobs",
help="List the jobs",
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_list_jobs'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_list_jobs"),
args=(ARG_DAG_ID_OPT, ARG_STATE, ARG_LIMIT, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='state',
+ name="state",
help="Get the status of a dag run",
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_state'),
- args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_state"),
+ args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_VERBOSE),
),
ActionCommand(
- name='next-execution',
+ name="next-execution",
help="Get the next execution datetimes of a DAG",
description=(
"Get the next execution datetimes of a DAG. It returns one execution unless the "
"num-executions option is given"
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_next_execution'),
- args=(ARG_DAG_ID, ARG_SUBDIR, ARG_NUM_EXECUTIONS),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_next_execution"),
+ args=(ARG_DAG_ID, ARG_SUBDIR, ARG_NUM_EXECUTIONS, ARG_VERBOSE),
),
ActionCommand(
- name='pause',
- help='Pause a DAG',
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_pause'),
- args=(ARG_DAG_ID, ARG_SUBDIR),
+ name="pause",
+ help="Pause a DAG",
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_pause"),
+ args=(ARG_DAG_ID, ARG_SUBDIR, ARG_VERBOSE),
),
ActionCommand(
- name='unpause',
- help='Resume a paused DAG',
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_unpause'),
- args=(ARG_DAG_ID, ARG_SUBDIR),
+ name="unpause",
+ help="Resume a paused DAG",
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_unpause"),
+ args=(ARG_DAG_ID, ARG_SUBDIR, ARG_VERBOSE),
),
ActionCommand(
- name='trigger',
- help='Trigger a DAG run',
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_trigger'),
- args=(ARG_DAG_ID, ARG_SUBDIR, ARG_RUN_ID, ARG_CONF, ARG_EXEC_DATE),
+ name="trigger",
+ help="Trigger a DAG run",
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_trigger"),
+ args=(ARG_DAG_ID, ARG_SUBDIR, ARG_RUN_ID, ARG_CONF, ARG_EXEC_DATE, ARG_VERBOSE),
),
ActionCommand(
- name='delete',
+ name="delete",
help="Delete all DB records related to the specified DAG",
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_delete'),
- args=(ARG_DAG_ID, ARG_YES),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_delete"),
+ args=(ARG_DAG_ID, ARG_YES, ARG_VERBOSE),
),
ActionCommand(
- name='show',
+ name="show",
help="Displays DAG's tasks with their dependencies",
description=(
"The --imgcat option only works in iTerm.\n"
@@ -1064,16 +1110,17 @@ class GroupCommand(NamedTuple):
"If you want to create a DOT file then you should execute the following command:\n"
"airflow dags show --save output.dot\n"
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_show'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_show"),
args=(
ARG_DAG_ID,
ARG_SUBDIR,
ARG_SAVE,
ARG_IMGCAT,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='show-dependencies',
+ name="show-dependencies",
help="Displays DAGs with their dependencies",
description=(
"The --imgcat option only works in iTerm.\n"
@@ -1092,15 +1139,16 @@ class GroupCommand(NamedTuple):
"If you want to create a DOT file then you should execute the following command:\n"
"airflow dags show-dependencies --save output.dot\n"
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_dependencies_show'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_dependencies_show"),
args=(
ARG_SUBDIR,
ARG_SAVE,
ARG_IMGCAT,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='backfill',
+ name="backfill",
help="Run subsections of a DAG for a specified date range",
description=(
"Run subsections of a DAG for a specified date range. If reset_dag_run option is used, "
@@ -1108,7 +1156,7 @@ class GroupCommand(NamedTuple):
"task_instances within the backfill date range. If rerun_failed_tasks is used, backfill "
"will auto re-run the previous failed task instances within the backfill date range"
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_backfill'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_backfill"),
args=(
ARG_DAG_ID,
ARG_TASK_REGEX,
@@ -1119,6 +1167,7 @@ class GroupCommand(NamedTuple):
ARG_DONOT_PICKLE,
ARG_YES,
ARG_CONTINUE_ON_FAILURES,
+ ARG_DISABLE_RETRY,
ARG_BF_IGNORE_DEPENDENCIES,
ARG_BF_IGNORE_FIRST_DEPENDS_ON_PAST,
ARG_SUBDIR,
@@ -1130,14 +1179,14 @@ class GroupCommand(NamedTuple):
ARG_RESET_DAG_RUN,
ARG_RERUN_FAILED_TASKS,
ARG_RUN_BACKWARDS,
+ ARG_TREAT_DAG_AS_REGEX,
),
),
ActionCommand(
- name='test',
+ name="test",
help="Execute one single DagRun",
description=(
- "Execute one single DagRun for a given DAG and execution date, "
- "using the DebugExecutor.\n"
+ "Execute one single DagRun for a given DAG and execution date.\n"
"\n"
"The --imgcat-dagrun option only works in iTerm.\n"
"\n"
@@ -1155,39 +1204,45 @@ class GroupCommand(NamedTuple):
"If you want to create a DOT file then you should execute the following command:\n"
"airflow dags test --save-dagrun output.dot\n"
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_test'),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_test"),
args=(
ARG_DAG_ID,
- ARG_EXECUTION_DATE,
+ ARG_EXECUTION_DATE_OPTIONAL,
+ ARG_CONF,
ARG_SUBDIR,
ARG_SHOW_DAGRUN,
ARG_IMGCAT_DAGRUN,
ARG_SAVE_DAGRUN,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='reserialize',
+ name="reserialize",
help="Reserialize all DAGs by parsing the DagBag files",
description=(
"Drop all serialized dags from the metadata DB. This will cause all DAGs to be reserialized "
"from the DagBag folder. This can be helpful if your serialized DAGs get out of sync with the "
"version of Airflow that you are running."
),
- func=lazy_load_command('airflow.cli.commands.dag_command.dag_reserialize'),
- args=(ARG_CLEAR_ONLY,),
+ func=lazy_load_command("airflow.cli.commands.dag_command.dag_reserialize"),
+ args=(
+ ARG_CLEAR_ONLY,
+ ARG_SUBDIR,
+ ARG_VERBOSE,
+ ),
),
)
TASKS_COMMANDS = (
ActionCommand(
- name='list',
+ name="list",
help="List the tasks within a DAG",
- func=lazy_load_command('airflow.cli.commands.task_command.task_list'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_list"),
args=(ARG_DAG_ID, ARG_TREE, ARG_SUBDIR, ARG_VERBOSE),
),
ActionCommand(
- name='clear',
+ name="clear",
help="Clear a set of task instances, as if they never ran",
- func=lazy_load_command('airflow.cli.commands.task_command.task_clear'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_clear"),
args=(
ARG_DAG_ID,
ARG_TASK_REGEX,
@@ -1202,12 +1257,13 @@ class GroupCommand(NamedTuple):
ARG_EXCLUDE_SUBDAGS,
ARG_EXCLUDE_PARENTDAG,
ARG_DAG_REGEX,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='state',
+ name="state",
help="Get the status of a task instance",
- func=lazy_load_command('airflow.cli.commands.task_command.task_state'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_state"),
args=(
ARG_DAG_ID,
ARG_TASK_ID,
@@ -1218,20 +1274,20 @@ class GroupCommand(NamedTuple):
),
),
ActionCommand(
- name='failed-deps',
+ name="failed-deps",
help="Returns the unmet dependencies for a task instance",
description=(
"Returns the unmet dependencies for a task instance from the perspective of the scheduler. "
"In other words, why a task instance doesn't get scheduled and then queued by the scheduler, "
"and then run by an executor."
),
- func=lazy_load_command('airflow.cli.commands.task_command.task_failed_deps'),
- args=(ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE_OR_RUN_ID, ARG_SUBDIR, ARG_MAP_INDEX),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_failed_deps"),
+ args=(ARG_DAG_ID, ARG_TASK_ID, ARG_EXECUTION_DATE_OR_RUN_ID, ARG_SUBDIR, ARG_MAP_INDEX, ARG_VERBOSE),
),
ActionCommand(
- name='render',
+ name="render",
help="Render a task instance's template(s)",
- func=lazy_load_command('airflow.cli.commands.task_command.task_render'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_render"),
args=(
ARG_DAG_ID,
ARG_TASK_ID,
@@ -1242,9 +1298,9 @@ class GroupCommand(NamedTuple):
),
),
ActionCommand(
- name='run',
+ name="run",
help="Run a single task instance",
- func=lazy_load_command('airflow.cli.commands.task_command.task_run'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_run"),
args=(
ARG_DAG_ID,
ARG_TASK_ID,
@@ -1265,133 +1321,135 @@ class GroupCommand(NamedTuple):
ARG_INTERACTIVE,
ARG_SHUT_DOWN_LOGGING,
ARG_MAP_INDEX,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='test',
+ name="test",
help="Test a task instance",
description=(
"Test a task instance. This will run a task without checking for dependencies or recording "
"its state in the database"
),
- func=lazy_load_command('airflow.cli.commands.task_command.task_test'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_test"),
args=(
ARG_DAG_ID,
ARG_TASK_ID,
- ARG_EXECUTION_DATE_OR_RUN_ID,
+ ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL,
ARG_SUBDIR,
ARG_DRY_RUN,
ARG_TASK_PARAMS,
ARG_POST_MORTEM,
ARG_ENV_VARS,
ARG_MAP_INDEX,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='states-for-dag-run',
+ name="states-for-dag-run",
help="Get the status of all task instances in a dag run",
- func=lazy_load_command('airflow.cli.commands.task_command.task_states_for_dag_run'),
+ func=lazy_load_command("airflow.cli.commands.task_command.task_states_for_dag_run"),
args=(ARG_DAG_ID, ARG_EXECUTION_DATE_OR_RUN_ID, ARG_OUTPUT, ARG_VERBOSE),
),
)
POOLS_COMMANDS = (
ActionCommand(
- name='list',
- help='List pools',
- func=lazy_load_command('airflow.cli.commands.pool_command.pool_list'),
+ name="list",
+ help="List pools",
+ func=lazy_load_command("airflow.cli.commands.pool_command.pool_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='get',
- help='Get pool size',
- func=lazy_load_command('airflow.cli.commands.pool_command.pool_get'),
+ name="get",
+ help="Get pool size",
+ func=lazy_load_command("airflow.cli.commands.pool_command.pool_get"),
args=(ARG_POOL_NAME, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='set',
- help='Configure pool',
- func=lazy_load_command('airflow.cli.commands.pool_command.pool_set'),
+ name="set",
+ help="Configure pool",
+ func=lazy_load_command("airflow.cli.commands.pool_command.pool_set"),
args=(ARG_POOL_NAME, ARG_POOL_SLOTS, ARG_POOL_DESCRIPTION, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='delete',
- help='Delete pool',
- func=lazy_load_command('airflow.cli.commands.pool_command.pool_delete'),
+ name="delete",
+ help="Delete pool",
+ func=lazy_load_command("airflow.cli.commands.pool_command.pool_delete"),
args=(ARG_POOL_NAME, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='import',
- help='Import pools',
- func=lazy_load_command('airflow.cli.commands.pool_command.pool_import'),
+ name="import",
+ help="Import pools",
+ func=lazy_load_command("airflow.cli.commands.pool_command.pool_import"),
args=(ARG_POOL_IMPORT, ARG_VERBOSE),
),
ActionCommand(
- name='export',
- help='Export all pools',
- func=lazy_load_command('airflow.cli.commands.pool_command.pool_export'),
- args=(ARG_POOL_EXPORT,),
+ name="export",
+ help="Export all pools",
+ func=lazy_load_command("airflow.cli.commands.pool_command.pool_export"),
+ args=(ARG_POOL_EXPORT, ARG_VERBOSE),
),
)
VARIABLES_COMMANDS = (
ActionCommand(
- name='list',
- help='List variables',
- func=lazy_load_command('airflow.cli.commands.variable_command.variables_list'),
+ name="list",
+ help="List variables",
+ func=lazy_load_command("airflow.cli.commands.variable_command.variables_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='get',
- help='Get variable',
- func=lazy_load_command('airflow.cli.commands.variable_command.variables_get'),
- args=(ARG_VAR, ARG_JSON, ARG_DEFAULT, ARG_VERBOSE),
+ name="get",
+ help="Get variable",
+ func=lazy_load_command("airflow.cli.commands.variable_command.variables_get"),
+ args=(ARG_VAR, ARG_DESERIALIZE_JSON, ARG_DEFAULT, ARG_VERBOSE),
),
ActionCommand(
- name='set',
- help='Set variable',
- func=lazy_load_command('airflow.cli.commands.variable_command.variables_set'),
- args=(ARG_VAR, ARG_VAR_VALUE, ARG_JSON),
+ name="set",
+ help="Set variable",
+ func=lazy_load_command("airflow.cli.commands.variable_command.variables_set"),
+ args=(ARG_VAR, ARG_VAR_VALUE, ARG_SERIALIZE_JSON, ARG_VERBOSE),
),
ActionCommand(
- name='delete',
- help='Delete variable',
- func=lazy_load_command('airflow.cli.commands.variable_command.variables_delete'),
- args=(ARG_VAR,),
+ name="delete",
+ help="Delete variable",
+ func=lazy_load_command("airflow.cli.commands.variable_command.variables_delete"),
+ args=(ARG_VAR, ARG_VERBOSE),
),
ActionCommand(
- name='import',
- help='Import variables',
- func=lazy_load_command('airflow.cli.commands.variable_command.variables_import'),
- args=(ARG_VAR_IMPORT,),
+ name="import",
+ help="Import variables",
+ func=lazy_load_command("airflow.cli.commands.variable_command.variables_import"),
+ args=(ARG_VAR_IMPORT, ARG_VERBOSE),
),
ActionCommand(
- name='export',
- help='Export all variables',
- func=lazy_load_command('airflow.cli.commands.variable_command.variables_export'),
- args=(ARG_VAR_EXPORT,),
+ name="export",
+ help="Export all variables",
+ func=lazy_load_command("airflow.cli.commands.variable_command.variables_export"),
+ args=(ARG_VAR_EXPORT, ARG_VERBOSE),
),
)
DB_COMMANDS = (
ActionCommand(
- name='init',
+ name="init",
help="Initialize the metadata database",
- func=lazy_load_command('airflow.cli.commands.db_command.initdb'),
- args=(),
+ func=lazy_load_command("airflow.cli.commands.db_command.initdb"),
+ args=(ARG_VERBOSE,),
),
ActionCommand(
name="check-migrations",
help="Check if migrations have finished",
description="Check if migrations have finished (or continually check until timeout)",
- func=lazy_load_command('airflow.cli.commands.db_command.check_migrations'),
- args=(ARG_MIGRATION_TIMEOUT,),
+ func=lazy_load_command("airflow.cli.commands.db_command.check_migrations"),
+ args=(ARG_MIGRATION_TIMEOUT, ARG_VERBOSE),
),
ActionCommand(
- name='reset',
+ name="reset",
help="Burn down and rebuild the metadata database",
- func=lazy_load_command('airflow.cli.commands.db_command.resetdb'),
- args=(ARG_YES, ARG_DB_SKIP_INIT),
+ func=lazy_load_command("airflow.cli.commands.db_command.resetdb"),
+ args=(ARG_YES, ARG_DB_SKIP_INIT, ARG_VERBOSE),
),
ActionCommand(
- name='upgrade',
+ name="upgrade",
help="Upgrade the metadata database to latest version",
description=(
"Upgrade the schema of the metadata database. "
@@ -1400,17 +1458,19 @@ class GroupCommand(NamedTuple):
"``--show-sql-only``, because if actually *running* migrations, we should only "
"migrate from the *current* Alembic revision."
),
- func=lazy_load_command('airflow.cli.commands.db_command.upgradedb'),
+ func=lazy_load_command("airflow.cli.commands.db_command.upgradedb"),
args=(
ARG_DB_REVISION__UPGRADE,
ARG_DB_VERSION__UPGRADE,
ARG_DB_SQL_ONLY,
ARG_DB_FROM_REVISION,
ARG_DB_FROM_VERSION,
+ ARG_DB_RESERIALIZE_DAGS,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='downgrade',
+ name="downgrade",
help="Downgrade the schema of the metadata database.",
description=(
"Downgrade the schema of the metadata database. "
@@ -1420,7 +1480,7 @@ class GroupCommand(NamedTuple):
"because if actually *running* migrations, we should only migrate from the *current* Alembic "
"revision."
),
- func=lazy_load_command('airflow.cli.commands.db_command.downgrade'),
+ func=lazy_load_command("airflow.cli.commands.db_command.downgrade"),
args=(
ARG_DB_REVISION__DOWNGRADE,
ARG_DB_VERSION__DOWNGRADE,
@@ -1428,61 +1488,63 @@ class GroupCommand(NamedTuple):
ARG_YES,
ARG_DB_FROM_REVISION,
ARG_DB_FROM_VERSION,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='shell',
+ name="shell",
help="Runs a shell to access the database",
- func=lazy_load_command('airflow.cli.commands.db_command.shell'),
- args=(),
+ func=lazy_load_command("airflow.cli.commands.db_command.shell"),
+ args=(ARG_VERBOSE,),
),
ActionCommand(
- name='check',
+ name="check",
help="Check if the database can be reached",
- func=lazy_load_command('airflow.cli.commands.db_command.check'),
- args=(),
+ func=lazy_load_command("airflow.cli.commands.db_command.check"),
+ args=(ARG_VERBOSE,),
),
ActionCommand(
- name='clean',
+ name="clean",
help="Purge old records in metastore tables",
- func=lazy_load_command('airflow.cli.commands.db_command.cleanup_tables'),
+ func=lazy_load_command("airflow.cli.commands.db_command.cleanup_tables"),
args=(
ARG_DB_TABLES,
ARG_DB_DRY_RUN,
ARG_DB_CLEANUP_TIMESTAMP,
ARG_VERBOSE,
ARG_YES,
+ ARG_DB_SKIP_ARCHIVE,
),
),
)
CONNECTIONS_COMMANDS = (
ActionCommand(
- name='get',
- help='Get a connection',
- func=lazy_load_command('airflow.cli.commands.connection_command.connections_get'),
+ name="get",
+ help="Get a connection",
+ func=lazy_load_command("airflow.cli.commands.connection_command.connections_get"),
args=(ARG_CONN_ID, ARG_COLOR, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='list',
- help='List connections',
- func=lazy_load_command('airflow.cli.commands.connection_command.connections_list'),
+ name="list",
+ help="List connections",
+ func=lazy_load_command("airflow.cli.commands.connection_command.connections_list"),
args=(ARG_OUTPUT, ARG_VERBOSE, ARG_CONN_ID_FILTER),
),
ActionCommand(
- name='add',
- help='Add a connection',
- func=lazy_load_command('airflow.cli.commands.connection_command.connections_add'),
+ name="add",
+ help="Add a connection",
+ func=lazy_load_command("airflow.cli.commands.connection_command.connections_add"),
args=(ARG_CONN_ID, ARG_CONN_URI, ARG_CONN_JSON, ARG_CONN_EXTRA) + tuple(ALTERNATIVE_CONN_SPECS_ARGS),
),
ActionCommand(
- name='delete',
- help='Delete a connection',
- func=lazy_load_command('airflow.cli.commands.connection_command.connections_delete'),
- args=(ARG_CONN_ID, ARG_COLOR),
+ name="delete",
+ help="Delete a connection",
+ func=lazy_load_command("airflow.cli.commands.connection_command.connections_delete"),
+ args=(ARG_CONN_ID, ARG_COLOR, ARG_VERBOSE),
),
ActionCommand(
- name='export',
- help='Export all connections',
+ name="export",
+ help="Export all connections",
description=(
"All connections can be exported in STDOUT using the following command:\n"
"airflow connections export -\n"
@@ -1498,96 +1560,100 @@ class GroupCommand(NamedTuple):
"is used to serialize the connection by passing `uri` or `json` with option "
"`--serialization-format`.\n"
),
- func=lazy_load_command('airflow.cli.commands.connection_command.connections_export'),
+ func=lazy_load_command("airflow.cli.commands.connection_command.connections_export"),
args=(
ARG_CONN_EXPORT,
ARG_CONN_EXPORT_FORMAT,
ARG_CONN_EXPORT_FILE_FORMAT,
ARG_CONN_SERIALIZATION_FORMAT,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='import',
- help='Import connections from a file',
+ name="import",
+ help="Import connections from a file",
description=(
"Connections can be imported from the output of the export command.\n"
"The filetype must be json, yaml or env and will be automatically inferred."
),
- func=lazy_load_command('airflow.cli.commands.connection_command.connections_import'),
- args=(ARG_CONN_IMPORT,),
+ func=lazy_load_command("airflow.cli.commands.connection_command.connections_import"),
+ args=(
+ ARG_CONN_IMPORT,
+ ARG_VERBOSE,
+ ),
),
)
PROVIDERS_COMMANDS = (
ActionCommand(
- name='list',
- help='List installed providers',
- func=lazy_load_command('airflow.cli.commands.provider_command.providers_list'),
+ name="list",
+ help="List installed providers",
+ func=lazy_load_command("airflow.cli.commands.provider_command.providers_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='get',
- help='Get detailed information about a provider',
- func=lazy_load_command('airflow.cli.commands.provider_command.provider_get'),
+ name="get",
+ help="Get detailed information about a provider",
+ func=lazy_load_command("airflow.cli.commands.provider_command.provider_get"),
args=(ARG_OUTPUT, ARG_VERBOSE, ARG_FULL, ARG_COLOR, ARG_PROVIDER_NAME),
),
ActionCommand(
- name='links',
- help='List extra links registered by the providers',
- func=lazy_load_command('airflow.cli.commands.provider_command.extra_links_list'),
+ name="links",
+ help="List extra links registered by the providers",
+ func=lazy_load_command("airflow.cli.commands.provider_command.extra_links_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='widgets',
- help='Get information about registered connection form widgets',
- func=lazy_load_command('airflow.cli.commands.provider_command.connection_form_widget_list'),
+ name="widgets",
+ help="Get information about registered connection form widgets",
+ func=lazy_load_command("airflow.cli.commands.provider_command.connection_form_widget_list"),
args=(
ARG_OUTPUT,
ARG_VERBOSE,
),
),
ActionCommand(
- name='hooks',
- help='List registered provider hooks',
- func=lazy_load_command('airflow.cli.commands.provider_command.hooks_list'),
+ name="hooks",
+ help="List registered provider hooks",
+ func=lazy_load_command("airflow.cli.commands.provider_command.hooks_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='behaviours',
- help='Get information about registered connection types with custom behaviours',
- func=lazy_load_command('airflow.cli.commands.provider_command.connection_field_behaviours'),
+ name="behaviours",
+ help="Get information about registered connection types with custom behaviours",
+ func=lazy_load_command("airflow.cli.commands.provider_command.connection_field_behaviours"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='logging',
- help='Get information about task logging handlers provided',
- func=lazy_load_command('airflow.cli.commands.provider_command.logging_list'),
+ name="logging",
+ help="Get information about task logging handlers provided",
+ func=lazy_load_command("airflow.cli.commands.provider_command.logging_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='secrets',
- help='Get information about secrets backends provided',
- func=lazy_load_command('airflow.cli.commands.provider_command.secrets_backends_list'),
+ name="secrets",
+ help="Get information about secrets backends provided",
+ func=lazy_load_command("airflow.cli.commands.provider_command.secrets_backends_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='auth',
- help='Get information about API auth backends provided',
- func=lazy_load_command('airflow.cli.commands.provider_command.auth_backend_list'),
+ name="auth",
+ help="Get information about API auth backends provided",
+ func=lazy_load_command("airflow.cli.commands.provider_command.auth_backend_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
)
USERS_COMMANDS = (
ActionCommand(
- name='list',
- help='List users',
- func=lazy_load_command('airflow.cli.commands.user_command.users_list'),
+ name="list",
+ help="List users",
+ func=lazy_load_command("airflow.cli.commands.user_command.users_list"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='create',
- help='Create a user',
- func=lazy_load_command('airflow.cli.commands.user_command.users_create'),
+ name="create",
+ help="Create a user",
+ func=lazy_load_command("airflow.cli.commands.user_command.users_create"),
args=(
ARG_ROLE,
ARG_USERNAME,
@@ -1596,82 +1662,101 @@ class GroupCommand(NamedTuple):
ARG_LASTNAME,
ARG_PASSWORD,
ARG_USE_RANDOM_PASSWORD,
+ ARG_VERBOSE,
),
epilog=(
- 'examples:\n'
+ "examples:\n"
'To create a user with the "Admin" role and the username "admin", run:\n'
- '\n'
- ' $ airflow users create \\\n'
- ' --username admin \\\n'
- ' --firstname FIRST_NAME \\\n'
- ' --lastname LAST_NAME \\\n'
- ' --role Admin \\\n'
- ' --email admin@example.org'
+ "\n"
+ " $ airflow users create \\\n"
+ " --username admin \\\n"
+ " --firstname FIRST_NAME \\\n"
+ " --lastname LAST_NAME \\\n"
+ " --role Admin \\\n"
+ " --email admin@example.org"
),
),
ActionCommand(
- name='delete',
- help='Delete a user',
- func=lazy_load_command('airflow.cli.commands.user_command.users_delete'),
- args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL),
+ name="delete",
+ help="Delete a user",
+ func=lazy_load_command("airflow.cli.commands.user_command.users_delete"),
+ args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_VERBOSE),
),
ActionCommand(
- name='add-role',
- help='Add role to a user',
- func=lazy_load_command('airflow.cli.commands.user_command.add_role'),
- args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE),
+ name="add-role",
+ help="Add role to a user",
+ func=lazy_load_command("airflow.cli.commands.user_command.add_role"),
+ args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE, ARG_VERBOSE),
),
ActionCommand(
- name='remove-role',
- help='Remove role from a user',
- func=lazy_load_command('airflow.cli.commands.user_command.remove_role'),
- args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE),
+ name="remove-role",
+ help="Remove role from a user",
+ func=lazy_load_command("airflow.cli.commands.user_command.remove_role"),
+ args=(ARG_USERNAME_OPTIONAL, ARG_EMAIL_OPTIONAL, ARG_ROLE, ARG_VERBOSE),
),
ActionCommand(
- name='import',
- help='Import users',
- func=lazy_load_command('airflow.cli.commands.user_command.users_import'),
- args=(ARG_USER_IMPORT,),
+ name="import",
+ help="Import users",
+ func=lazy_load_command("airflow.cli.commands.user_command.users_import"),
+ args=(ARG_USER_IMPORT, ARG_VERBOSE),
),
ActionCommand(
- name='export',
- help='Export all users',
- func=lazy_load_command('airflow.cli.commands.user_command.users_export'),
- args=(ARG_USER_EXPORT,),
+ name="export",
+ help="Export all users",
+ func=lazy_load_command("airflow.cli.commands.user_command.users_export"),
+ args=(ARG_USER_EXPORT, ARG_VERBOSE),
),
)
ROLES_COMMANDS = (
ActionCommand(
- name='list',
- help='List roles',
- func=lazy_load_command('airflow.cli.commands.role_command.roles_list'),
- args=(ARG_OUTPUT, ARG_VERBOSE),
+ name="list",
+ help="List roles",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_list"),
+ args=(ARG_PERMISSIONS, ARG_OUTPUT, ARG_VERBOSE),
),
ActionCommand(
- name='create',
- help='Create role',
- func=lazy_load_command('airflow.cli.commands.role_command.roles_create'),
+ name="create",
+ help="Create role",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_create"),
args=(ARG_ROLES, ARG_VERBOSE),
),
ActionCommand(
- name='export',
- help='Export roles (without permissions) from db to JSON file',
- func=lazy_load_command('airflow.cli.commands.role_command.roles_export'),
+ name="delete",
+ help="Delete role",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_delete"),
+ args=(ARG_ROLES, ARG_VERBOSE),
+ ),
+ ActionCommand(
+ name="add-perms",
+ help="Add roles permissions",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_add_perms"),
+ args=(ARG_ROLES, ARG_ROLE_RESOURCE, ARG_ROLE_ACTION_REQUIRED, ARG_VERBOSE),
+ ),
+ ActionCommand(
+ name="del-perms",
+ help="Delete roles permissions",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_del_perms"),
+ args=(ARG_ROLES, ARG_ROLE_RESOURCE, ARG_ROLE_ACTION, ARG_VERBOSE),
+ ),
+ ActionCommand(
+ name="export",
+ help="Export roles (without permissions) from db to JSON file",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_export"),
args=(ARG_ROLE_EXPORT, ARG_ROLE_EXPORT_FMT, ARG_VERBOSE),
),
ActionCommand(
- name='import',
- help='Import roles (without permissions) from JSON file to db',
- func=lazy_load_command('airflow.cli.commands.role_command.roles_import'),
+ name="import",
+ help="Import roles (without permissions) from JSON file to db",
+ func=lazy_load_command("airflow.cli.commands.role_command.roles_import"),
args=(ARG_ROLE_IMPORT, ARG_VERBOSE),
),
)
CELERY_COMMANDS = (
ActionCommand(
- name='worker',
+ name="worker",
help="Start a Celery worker node",
- func=lazy_load_command('airflow.cli.commands.celery_command.worker'),
+ func=lazy_load_command("airflow.cli.commands.celery_command.worker"),
args=(
ARG_QUEUES,
ARG_CONCURRENCY,
@@ -1686,12 +1771,13 @@ class GroupCommand(NamedTuple):
ARG_SKIP_SERVE_LOGS,
ARG_WITHOUT_MINGLE,
ARG_WITHOUT_GOSSIP,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='flower',
+ name="flower",
help="Start a Celery Flower",
- func=lazy_load_command('airflow.cli.commands.celery_command.flower'),
+ func=lazy_load_command("airflow.cli.commands.celery_command.flower"),
args=(
ARG_FLOWER_HOSTNAME,
ARG_FLOWER_PORT,
@@ -1704,117 +1790,135 @@ class GroupCommand(NamedTuple):
ARG_STDOUT,
ARG_STDERR,
ARG_LOG_FILE,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='stop',
+ name="stop",
help="Stop the Celery worker gracefully",
- func=lazy_load_command('airflow.cli.commands.celery_command.stop_worker'),
- args=(ARG_PID,),
+ func=lazy_load_command("airflow.cli.commands.celery_command.stop_worker"),
+ args=(ARG_PID, ARG_VERBOSE),
),
)
CONFIG_COMMANDS = (
ActionCommand(
- name='get-value',
- help='Print the value of the configuration',
- func=lazy_load_command('airflow.cli.commands.config_command.get_value'),
+ name="get-value",
+ help="Print the value of the configuration",
+ func=lazy_load_command("airflow.cli.commands.config_command.get_value"),
args=(
ARG_SECTION,
ARG_OPTION,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='list',
- help='List options for the configuration',
- func=lazy_load_command('airflow.cli.commands.config_command.show_config'),
- args=(ARG_COLOR,),
+ name="list",
+ help="List options for the configuration",
+ func=lazy_load_command("airflow.cli.commands.config_command.show_config"),
+ args=(ARG_COLOR, ARG_VERBOSE),
),
)
KUBERNETES_COMMANDS = (
ActionCommand(
- name='cleanup-pods',
+ name="cleanup-pods",
help=(
"Clean up Kubernetes pods "
"(created by KubernetesExecutor/KubernetesPodOperator) "
"in evicted/failed/succeeded/pending states"
),
- func=lazy_load_command('airflow.cli.commands.kubernetes_command.cleanup_pods'),
- args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES),
+ func=lazy_load_command("airflow.cli.commands.kubernetes_command.cleanup_pods"),
+ args=(ARG_NAMESPACE, ARG_MIN_PENDING_MINUTES, ARG_VERBOSE),
),
ActionCommand(
- name='generate-dag-yaml',
+ name="generate-dag-yaml",
help="Generate YAML files for all tasks in DAG. Useful for debugging tasks without "
"launching into a cluster",
- func=lazy_load_command('airflow.cli.commands.kubernetes_command.generate_pod_yaml'),
- args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH),
+ func=lazy_load_command("airflow.cli.commands.kubernetes_command.generate_pod_yaml"),
+ args=(ARG_DAG_ID, ARG_EXECUTION_DATE, ARG_SUBDIR, ARG_OUTPUT_PATH, ARG_VERBOSE),
),
)
JOBS_COMMANDS = (
ActionCommand(
- name='check',
+ name="check",
help="Checks if job(s) are still alive",
- func=lazy_load_command('airflow.cli.commands.jobs_command.check'),
- args=(ARG_JOB_TYPE_FILTER, ARG_JOB_HOSTNAME_FILTER, ARG_JOB_LIMIT, ARG_ALLOW_MULTIPLE),
+ func=lazy_load_command("airflow.cli.commands.jobs_command.check"),
+ args=(
+ ARG_JOB_TYPE_FILTER,
+ ARG_JOB_HOSTNAME_FILTER,
+ ARG_JOB_HOSTNAME_CALLABLE_FILTER,
+ ARG_JOB_LIMIT,
+ ARG_ALLOW_MULTIPLE,
+ ARG_VERBOSE,
+ ),
epilog=(
- 'examples:\n'
- 'To check if the local scheduler is still working properly, run:\n'
- '\n'
- ' $ airflow jobs check --job-type SchedulerJob --hostname "$(hostname)"\n'
- '\n'
- 'To check if any scheduler is running when you are using high availability, run:\n'
- '\n'
- ' $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100'
+ "examples:\n"
+ "To check if the local scheduler is still working properly, run:\n"
+ "\n"
+ " $ airflow jobs check --job-type SchedulerJob --local\n"
+ "\n"
+ "To check if any scheduler is running when you are using high availability, run:\n"
+ "\n"
+ " $ airflow jobs check --job-type SchedulerJob --allow-multiple --limit 100"
),
),
)
-airflow_commands: List[CLICommand] = [
+airflow_commands: list[CLICommand] = [
GroupCommand(
- name='dags',
- help='Manage DAGs',
+ name="dags",
+ help="Manage DAGs",
subcommands=DAGS_COMMANDS,
),
GroupCommand(
- name="kubernetes", help='Tools to help run the KubernetesExecutor', subcommands=KUBERNETES_COMMANDS
+ name="kubernetes", help="Tools to help run the KubernetesExecutor", subcommands=KUBERNETES_COMMANDS
),
GroupCommand(
- name='tasks',
- help='Manage tasks',
+ name="tasks",
+ help="Manage tasks",
subcommands=TASKS_COMMANDS,
),
GroupCommand(
- name='pools',
+ name="pools",
help="Manage pools",
subcommands=POOLS_COMMANDS,
),
GroupCommand(
- name='variables',
+ name="variables",
help="Manage variables",
subcommands=VARIABLES_COMMANDS,
),
GroupCommand(
- name='jobs',
+ name="jobs",
help="Manage jobs",
subcommands=JOBS_COMMANDS,
),
GroupCommand(
- name='db',
+ name="db",
help="Database operations",
subcommands=DB_COMMANDS,
),
ActionCommand(
- name='kerberos',
+ name="kerberos",
help="Start a kerberos ticket renewer",
- func=lazy_load_command('airflow.cli.commands.kerberos_command.kerberos'),
- args=(ARG_PRINCIPAL, ARG_KEYTAB, ARG_PID, ARG_DAEMON, ARG_STDOUT, ARG_STDERR, ARG_LOG_FILE),
+ func=lazy_load_command("airflow.cli.commands.kerberos_command.kerberos"),
+ args=(
+ ARG_PRINCIPAL,
+ ARG_KEYTAB,
+ ARG_PID,
+ ARG_DAEMON,
+ ARG_STDOUT,
+ ARG_STDERR,
+ ARG_LOG_FILE,
+ ARG_VERBOSE,
+ ),
),
ActionCommand(
- name='webserver',
+ name="webserver",
help="Start an Airflow webserver instance",
- func=lazy_load_command('airflow.cli.commands.webserver_command.webserver'),
+ func=lazy_load_command("airflow.cli.commands.webserver_command.webserver"),
args=(
ARG_PORT,
ARG_WORKERS,
@@ -1835,9 +1939,9 @@ class GroupCommand(NamedTuple):
),
),
ActionCommand(
- name='scheduler',
+ name="scheduler",
help="Start a scheduler instance",
- func=lazy_load_command('airflow.cli.commands.scheduler_command.scheduler'),
+ func=lazy_load_command("airflow.cli.commands.scheduler_command.scheduler"),
args=(
ARG_SUBDIR,
ARG_NUM_RUNS,
@@ -1848,20 +1952,21 @@ class GroupCommand(NamedTuple):
ARG_STDERR,
ARG_LOG_FILE,
ARG_SKIP_SERVE_LOGS,
+ ARG_VERBOSE,
),
epilog=(
- 'Signals:\n'
- '\n'
- ' - SIGUSR2: Dump a snapshot of task state being tracked by the executor.\n'
- '\n'
- ' Example:\n'
+ "Signals:\n"
+ "\n"
+ " - SIGUSR2: Dump a snapshot of task state being tracked by the executor.\n"
+ "\n"
+ " Example:\n"
' pkill -f -USR2 "airflow scheduler"'
),
),
ActionCommand(
- name='triggerer',
+ name="triggerer",
help="Start a triggerer instance",
- func=lazy_load_command('airflow.cli.commands.triggerer_command.triggerer'),
+ func=lazy_load_command("airflow.cli.commands.triggerer_command.triggerer"),
args=(
ARG_PID,
ARG_DAEMON,
@@ -1869,12 +1974,13 @@ class GroupCommand(NamedTuple):
ARG_STDERR,
ARG_LOG_FILE,
ARG_CAPACITY,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='dag-processor',
+ name="dag-processor",
help="Start a standalone Dag Processor instance",
- func=lazy_load_command('airflow.cli.commands.dag_processor_command.dag_processor'),
+ func=lazy_load_command("airflow.cli.commands.dag_processor_command.dag_processor"),
args=(
ARG_PID,
ARG_DAEMON,
@@ -1884,62 +1990,63 @@ class GroupCommand(NamedTuple):
ARG_STDOUT,
ARG_STDERR,
ARG_LOG_FILE,
+ ARG_VERBOSE,
),
),
ActionCommand(
- name='version',
+ name="version",
help="Show the version",
- func=lazy_load_command('airflow.cli.commands.version_command.version'),
+ func=lazy_load_command("airflow.cli.commands.version_command.version"),
args=(),
),
ActionCommand(
- name='cheat-sheet',
+ name="cheat-sheet",
help="Display cheat sheet",
- func=lazy_load_command('airflow.cli.commands.cheat_sheet_command.cheat_sheet'),
+ func=lazy_load_command("airflow.cli.commands.cheat_sheet_command.cheat_sheet"),
args=(ARG_VERBOSE,),
),
GroupCommand(
- name='connections',
+ name="connections",
help="Manage connections",
subcommands=CONNECTIONS_COMMANDS,
),
GroupCommand(
- name='providers',
+ name="providers",
help="Display providers",
subcommands=PROVIDERS_COMMANDS,
),
GroupCommand(
- name='users',
+ name="users",
help="Manage users",
subcommands=USERS_COMMANDS,
),
GroupCommand(
- name='roles',
- help='Manage roles',
+ name="roles",
+ help="Manage roles",
subcommands=ROLES_COMMANDS,
),
ActionCommand(
- name='sync-perm',
+ name="sync-perm",
help="Update permissions for existing roles and optionally DAGs",
- func=lazy_load_command('airflow.cli.commands.sync_perm_command.sync_perm'),
- args=(ARG_INCLUDE_DAGS,),
+ func=lazy_load_command("airflow.cli.commands.sync_perm_command.sync_perm"),
+ args=(ARG_INCLUDE_DAGS, ARG_VERBOSE),
),
ActionCommand(
- name='rotate-fernet-key',
- func=lazy_load_command('airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key'),
- help='Rotate encrypted connection credentials and variables',
+ name="rotate-fernet-key",
+ func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
+ help="Rotate encrypted connection credentials and variables",
description=(
- 'Rotate all encrypted connection credentials and variables; see '
- 'https://airflow.apache.org/docs/apache-airflow/stable/howto/secure-connections.html'
- '#rotating-encryption-keys'
+ "Rotate all encrypted connection credentials and variables; see "
+ "https://airflow.apache.org/docs/apache-airflow/stable/howto/secure-connections.html"
+ "#rotating-encryption-keys"
),
args=(),
),
- GroupCommand(name="config", help='View configuration', subcommands=CONFIG_COMMANDS),
+ GroupCommand(name="config", help="View configuration", subcommands=CONFIG_COMMANDS),
ActionCommand(
- name='info',
- help='Show information about current Airflow and environment',
- func=lazy_load_command('airflow.cli.commands.info_command.show_info'),
+ name="info",
+ help="Show information about current Airflow and environment",
+ func=lazy_load_command("airflow.cli.commands.info_command.show_info"),
args=(
ARG_ANONYMIZE,
ARG_FILE_IO,
@@ -1948,53 +2055,53 @@ class GroupCommand(NamedTuple):
),
),
ActionCommand(
- name='plugins',
- help='Dump information about loaded plugins',
- func=lazy_load_command('airflow.cli.commands.plugins_command.dump_plugins'),
+ name="plugins",
+ help="Dump information about loaded plugins",
+ func=lazy_load_command("airflow.cli.commands.plugins_command.dump_plugins"),
args=(ARG_OUTPUT, ARG_VERBOSE),
),
GroupCommand(
name="celery",
- help='Celery components',
+ help="Celery components",
description=(
- 'Start celery components. Works only when using CeleryExecutor. For more information, see '
- 'https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html'
+ "Start celery components. Works only when using CeleryExecutor. For more information, see "
+ "https://airflow.apache.org/docs/apache-airflow/stable/executor/celery.html"
),
subcommands=CELERY_COMMANDS,
),
ActionCommand(
- name='standalone',
- help='Run an all-in-one copy of Airflow',
- func=lazy_load_command('airflow.cli.commands.standalone_command.standalone'),
+ name="standalone",
+ help="Run an all-in-one copy of Airflow",
+ func=lazy_load_command("airflow.cli.commands.standalone_command.standalone"),
args=tuple(),
),
]
-ALL_COMMANDS_DICT: Dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands}
+ALL_COMMANDS_DICT: dict[str, CLICommand] = {sp.name: sp for sp in airflow_commands}
def _remove_dag_id_opt(command: ActionCommand):
cmd = command._asdict()
- cmd['args'] = (arg for arg in command.args if arg is not ARG_DAG_ID)
+ cmd["args"] = (arg for arg in command.args if arg is not ARG_DAG_ID)
return ActionCommand(**cmd)
-dag_cli_commands: List[CLICommand] = [
+dag_cli_commands: list[CLICommand] = [
GroupCommand(
- name='dags',
- help='Manage DAGs',
+ name="dags",
+ help="Manage DAGs",
subcommands=[
_remove_dag_id_opt(sp)
for sp in DAGS_COMMANDS
- if sp.name in ['backfill', 'list-runs', 'pause', 'unpause']
+ if sp.name in ["backfill", "list-runs", "pause", "unpause", "test"]
],
),
GroupCommand(
- name='tasks',
- help='Manage tasks',
- subcommands=[_remove_dag_id_opt(sp) for sp in TASKS_COMMANDS if sp.name in ['list', 'test', 'run']],
+ name="tasks",
+ help="Manage tasks",
+ subcommands=[_remove_dag_id_opt(sp) for sp in TASKS_COMMANDS if sp.name in ["list", "test", "run"]],
),
]
-DAG_CLI_DICT: Dict[str, CLICommand] = {sp.name: sp for sp in dag_cli_commands}
+DAG_CLI_DICT: dict[str, CLICommand] = {sp.name: sp for sp in dag_cli_commands}
class AirflowHelpFormatter(argparse.HelpFormatter):
@@ -2009,7 +2116,7 @@ def _format_action(self, action: Action):
parts = []
action_header = self._format_action_invocation(action)
- action_header = '%*s%s\n' % (self._current_indent, '', action_header)
+ action_header = "%*s%s\n" % (self._current_indent, "", action_header)
parts.append(action_header)
self._indent()
@@ -2018,14 +2125,14 @@ def _format_action(self, action: Action):
lambda d: isinstance(ALL_COMMANDS_DICT[d.dest], GroupCommand), subactions
)
parts.append("\n")
- parts.append('%*s%s:\n' % (self._current_indent, '', "Groups"))
+ parts.append("%*s%s:\n" % (self._current_indent, "", "Groups"))
self._indent()
for subaction in group_subcommands:
parts.append(self._format_action(subaction))
self._dedent()
parts.append("\n")
- parts.append('%*s%s:\n' % (self._current_indent, '', "Commands"))
+ parts.append("%*s%s:\n" % (self._current_indent, "", "Commands"))
self._indent()
for subaction in action_subcommands:
@@ -2041,9 +2148,9 @@ def _format_action(self, action: Action):
@lru_cache(maxsize=None)
def get_parser(dag_parser: bool = False) -> argparse.ArgumentParser:
- """Creates and returns command line argument parser"""
+ """Creates and returns command line argument parser."""
parser = DefaultHelpParser(prog="airflow", formatter_class=AirflowHelpFormatter)
- subparsers = parser.add_subparsers(dest='subcommand', metavar="GROUP_OR_COMMAND")
+ subparsers = parser.add_subparsers(dest="subcommand", metavar="GROUP_OR_COMMAND")
subparsers.required = True
command_dict = DAG_CLI_DICT if dag_parser else ALL_COMMANDS_DICT
@@ -2056,10 +2163,10 @@ def get_parser(dag_parser: bool = False) -> argparse.ArgumentParser:
def _sort_args(args: Iterable[Arg]) -> Iterable[Arg]:
- """Sort subcommand optional args, keep positional args"""
+ """Sort subcommand optional args, keep positional args."""
def get_long_option(arg: Arg):
- """Get long option from Arg.flags"""
+ """Get long option from Arg.flags."""
return arg.flags[0] if len(arg.flags) == 1 else arg.flags[1]
positional, optional = partition(lambda x: x.flags[0].startswith("-"), args)
diff --git a/airflow/cli/commands/celery_command.py b/airflow/cli/commands/celery_command.py
index affe032411c79..0d3e1295dcd4e 100644
--- a/airflow/cli/commands/celery_command.py
+++ b/airflow/cli/commands/celery_command.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Celery command"""
+"""Celery command."""
+from __future__ import annotations
+from contextlib import contextmanager
from multiprocessing import Process
-from typing import Optional
import daemon
import psutil
@@ -39,10 +40,10 @@
@cli_utils.action_cli
def flower(args):
- """Starts Flower, Celery monitoring tool"""
+ """Starts Flower, Celery monitoring tool."""
options = [
"flower",
- conf.get('celery', 'BROKER_URL'),
+ conf.get("celery", "BROKER_URL"),
f"--address={args.hostname}",
f"--port={args.port}",
]
@@ -67,11 +68,15 @@ def flower(args):
stderr=args.stderr,
log=args.log_file,
)
- with open(stdout, "w+") as stdout, open(stderr, "w+") as stderr:
+ with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
+ stdout.truncate(0)
+ stderr.truncate(0)
+
ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pidfile, -1),
stdout=stdout,
stderr=stderr,
+ umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
celery_app.start(options)
@@ -79,27 +84,21 @@ def flower(args):
celery_app.start(options)
-def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]:
- """Starts serve_logs sub-process"""
+@contextmanager
+def _serve_logs(skip_serve_logs: bool = False):
+ """Starts serve_logs sub-process."""
+ sub_proc = None
if skip_serve_logs is False:
sub_proc = Process(target=serve_logs)
sub_proc.start()
- return sub_proc
- return None
-
-
-def _run_worker(options, skip_serve_logs):
- sub_proc = _serve_logs(skip_serve_logs)
- try:
- celery_app.worker_main(options)
- finally:
- if sub_proc:
- sub_proc.terminate()
+ try:
+ yield
+ finally:
+ if sub_proc:
+ sub_proc.terminate()
@cli_utils.action_cli
def worker(args):
- """Starts Airflow Celery worker"""
+ """Starts Airflow Celery worker."""
# Disable connection pool so that celery worker does not hold an unnecessary db connection
settings.reconfigure_orm(disable_connection_pool=True)
if not settings.validate_session():
@@ -120,7 +119,7 @@ def worker(args):
log=args.log_file,
)
- if hasattr(celery_app.backend, 'ResultSession'):
+ if hasattr(celery_app.backend, "ResultSession"):
# Pre-create the database tables now, otherwise SQLA via Celery has a
# race condition where one of the subprocesses can die with "Table
# already exists" error, because SQLA checks for which tables exist,
@@ -137,31 +136,31 @@ def worker(args):
pass
# backwards-compatible: https://github.com/apache/airflow/pull/21506#pullrequestreview-879893763
- celery_log_level = conf.get('logging', 'CELERY_LOGGING_LEVEL')
+ celery_log_level = conf.get("logging", "CELERY_LOGGING_LEVEL")
if not celery_log_level:
- celery_log_level = conf.get('logging', 'LOGGING_LEVEL')
+ celery_log_level = conf.get("logging", "LOGGING_LEVEL")
# Setup Celery worker
options = [
- 'worker',
- '-O',
- 'fair',
- '--queues',
+ "worker",
+ "-O",
+ "fair",
+ "--queues",
args.queues,
- '--concurrency',
+ "--concurrency",
args.concurrency,
- '--hostname',
+ "--hostname",
args.celery_hostname,
- '--loglevel',
+ "--loglevel",
celery_log_level,
- '--pidfile',
+ "--pidfile",
pid_file_path,
]
if autoscale:
- options.extend(['--autoscale', autoscale])
+ options.extend(["--autoscale", autoscale])
if args.without_mingle:
- options.append('--without-mingle')
+ options.append("--without-mingle")
if args.without_gossip:
- options.append('--without-gossip')
+ options.append("--without-gossip")
if conf.has_option("celery", "pool"):
pool = conf.get("celery", "pool")
@@ -171,32 +170,39 @@ def worker(args):
# https://eventlet.net/doc/patching.html#monkey-patch
# Otherwise task instances hang on the workers and are never
# executed.
- maybe_patch_concurrency(['-P', pool])
+ maybe_patch_concurrency(["-P", pool])
if args.daemon:
# Run Celery worker as daemon
handle = setup_logging(log_file)
- with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle:
+ with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
if args.umask:
umask = args.umask
+ else:
+ umask = conf.get("celery", "worker_umask", fallback=settings.DAEMON_UMASK)
- ctx = daemon.DaemonContext(
+ stdout_handle.truncate(0)
+ stderr_handle.truncate(0)
+
+ daemon_context = daemon.DaemonContext(
files_preserve=[handle],
umask=int(umask, 8),
stdout=stdout_handle,
stderr=stderr_handle,
)
- with ctx:
- _run_worker(options=options, skip_serve_logs=skip_serve_logs)
+ with daemon_context, _serve_logs(skip_serve_logs):
+ celery_app.worker_main(options)
+
else:
# Run Celery worker in the same process
- _run_worker(options=options, skip_serve_logs=skip_serve_logs)
+ with _serve_logs(skip_serve_logs):
+ celery_app.worker_main(options)
@cli_utils.action_cli
def stop_worker(args):
- """Sends SIGTERM to Celery worker"""
+ """Sends SIGTERM to Celery worker."""
# Read PID from file
if args.pid:
pid_file_path = args.pid
diff --git a/airflow/cli/commands/cheat_sheet_command.py b/airflow/cli/commands/cheat_sheet_command.py
index 001a8721330cb..88d9c5940c7ca 100644
--- a/airflow/cli/commands/cheat_sheet_command.py
+++ b/airflow/cli/commands/cheat_sheet_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Iterable, List, Optional, Union
+from __future__ import annotations
+
+from typing import Iterable
from airflow.cli.cli_parser import ActionCommand, GroupCommand, airflow_commands
from airflow.cli.simple_table import AirflowConsole, SimpleTable
@@ -31,12 +33,12 @@ def display_commands_index():
"""Display list of all commands."""
def display_recursive(
- prefix: List[str],
- commands: Iterable[Union[GroupCommand, ActionCommand]],
- help_msg: Optional[str] = None,
+ prefix: list[str],
+ commands: Iterable[GroupCommand | ActionCommand],
+ help_msg: str | None = None,
):
- actions: List[ActionCommand] = []
- groups: List[GroupCommand] = []
+ actions: list[ActionCommand] = []
+ groups: list[GroupCommand] = []
for command in commands:
if isinstance(command, GroupCommand):
groups.append(command)
diff --git a/airflow/cli/commands/config_command.py b/airflow/cli/commands/config_command.py
index 1c2674fc811b2..f3a1eecfee958 100644
--- a/airflow/cli/commands/config_command.py
+++ b/airflow/cli/commands/config_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Config sub-commands"""
+"""Config sub-commands."""
+from __future__ import annotations
+
import io
import pygments
@@ -26,7 +28,7 @@
def show_config(args):
- """Show current application configuration"""
+ """Show current application configuration."""
with io.StringIO() as output:
conf.write(output)
code = output.getvalue()
@@ -36,12 +38,12 @@ def show_config(args):
def get_value(args):
- """Get one value from configuration"""
+ """Get one value from configuration."""
if not conf.has_section(args.section):
- raise SystemExit(f'The section [{args.section}] is not found in config.')
+ raise SystemExit(f"The section [{args.section}] is not found in config.")
if not conf.has_option(args.section, args.option):
- raise SystemExit(f'The option [{args.section}/{args.option}] is not found in config.')
+ raise SystemExit(f"The option [{args.section}/{args.option}] is not found in config.")
value = conf.get(args.section, args.option)
print(value)
diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py
index 8a0c0a3acb80a..7737f4fa2cab3 100644
--- a/airflow/cli/commands/connection_command.py
+++ b/airflow/cli/commands/connection_command.py
@@ -14,15 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Connection sub-commands"""
+"""Connection sub-commands."""
+from __future__ import annotations
+
import io
import json
import os
import sys
import warnings
from pathlib import Path
-from typing import Any, Dict, List
-from urllib.parse import urlparse, urlunparse
+from typing import Any
+from urllib.parse import urlsplit, urlunsplit
from sqlalchemy.orm import exc
@@ -38,21 +40,21 @@
from airflow.utils.session import create_session
-def _connection_mapper(conn: Connection) -> Dict[str, Any]:
+def _connection_mapper(conn: Connection) -> dict[str, Any]:
return {
- 'id': conn.id,
- 'conn_id': conn.conn_id,
- 'conn_type': conn.conn_type,
- 'description': conn.description,
- 'host': conn.host,
- 'schema': conn.schema,
- 'login': conn.login,
- 'password': conn.password,
- 'port': conn.port,
- 'is_encrypted': conn.is_encrypted,
- 'is_extra_encrypted': conn.is_encrypted,
- 'extra_dejson': conn.extra_dejson,
- 'get_uri': conn.get_uri(),
+ "id": conn.id,
+ "conn_id": conn.conn_id,
+ "conn_type": conn.conn_type,
+ "description": conn.description,
+ "host": conn.host,
+ "schema": conn.schema,
+ "login": conn.login,
+ "password": conn.password,
+ "port": conn.port,
+ "is_encrypted": conn.is_encrypted,
+ "is_extra_encrypted": conn.is_extra_encrypted,
+ "extra_dejson": conn.extra_dejson,
+ "get_uri": conn.get_uri(),
}
@@ -72,7 +74,7 @@ def connections_get(args):
@suppress_logs_and_warning
def connections_list(args):
- """Lists all connections at the command line"""
+ """Lists all connections at the command line."""
with create_session() as session:
query = session.query(Connection)
if args.conn_id:
@@ -99,14 +101,14 @@ def _connection_to_dict(conn: Connection) -> dict:
)
-def _format_connections(conns: List[Connection], file_format: str, serialization_format: str) -> str:
- if serialization_format == 'json':
+def _format_connections(conns: list[Connection], file_format: str, serialization_format: str) -> str:
+ if serialization_format == "json":
serializer_func = lambda x: json.dumps(_connection_to_dict(x))
- elif serialization_format == 'uri':
+ elif serialization_format == "uri":
serializer_func = Connection.get_uri
else:
raise SystemExit(f"Received unexpected value for `--serialization-format`: {serialization_format!r}")
- if file_format == '.env':
+ if file_format == ".env":
connections_env = ""
for conn in conns:
connections_env += f"{conn.conn_id}={serializer_func(conn)}\n"
@@ -116,29 +118,29 @@ def _format_connections(conns: List[Connection], file_format: str, serialization
for conn in conns:
connections_dict[conn.conn_id] = _connection_to_dict(conn)
- if file_format == '.yaml':
+ if file_format == ".yaml":
return yaml.dump(connections_dict)
- if file_format == '.json':
+ if file_format == ".json":
return json.dumps(connections_dict, indent=2)
return json.dumps(connections_dict)
def _is_stdout(fileio: io.TextIOWrapper) -> bool:
- return fileio.name == ''
+ return fileio.name == ""
def _valid_uri(uri: str) -> bool:
- """Check if a URI is valid, by checking if both scheme and netloc are available"""
- uri_parts = urlparse(uri)
- return uri_parts.scheme != '' and uri_parts.netloc != ''
+ """Check if a URI is valid, by checking if both scheme and netloc are available."""
+ uri_parts = urlsplit(uri)
+ return uri_parts.scheme != "" and uri_parts.netloc != ""
@cache
-def _get_connection_types():
+def _get_connection_types() -> list[str]:
"""Returns connection types available."""
- _connection_types = ['fs', 'mesos_framework-id', 'email', 'generic']
+ _connection_types = ["fs", "mesos_framework-id", "email", "generic"]
providers_manager = ProvidersManager()
for connection_type, provider_info in providers_manager.hooks.items():
if provider_info:
@@ -146,18 +148,14 @@ def _get_connection_types():
return _connection_types
-def _valid_conn_type(conn_type: str) -> bool:
- return conn_type in _get_connection_types()
-
-
def connections_export(args):
- """Exports all connections to a file"""
- file_formats = ['.yaml', '.json', '.env']
+ """Exports all connections to a file."""
+ file_formats = [".yaml", ".json", ".env"]
if args.format:
warnings.warn("Option `--format` is deprecated. Use `--file-format` instead.", DeprecationWarning)
if args.format and args.file_format:
- raise SystemExit('Option `--format` is deprecated. Use `--file-format` instead.')
- default_format = '.json'
+ raise SystemExit("Option `--format` is deprecated. Use `--file-format` instead.")
+ default_format = ".json"
provided_file_format = None
if args.format or args.file_format:
provided_file_format = f".{(args.format or args.file_format).lower()}"
@@ -175,7 +173,7 @@ def connections_export(args):
f"Unsupported file format. The file must have the extension {', '.join(file_formats)}."
)
- if args.serialization_format and not filetype == '.env':
+ if args.serialization_format and not filetype == ".env":
raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")
with create_session() as session:
@@ -184,7 +182,7 @@ def connections_export(args):
msg = _format_connections(
conns=connections,
file_format=filetype,
- serialization_format=args.serialization_format or 'uri',
+ serialization_format=args.serialization_format or "uri",
)
with args.file as f:
@@ -196,34 +194,34 @@ def connections_export(args):
print(f"Connections successfully exported to {args.file.name}.")
-alternative_conn_specs = ['conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port']
+alternative_conn_specs = ["conn_type", "conn_host", "conn_login", "conn_password", "conn_schema", "conn_port"]
@cli_utils.action_cli
def connections_add(args):
- """Adds new connection"""
+ """Adds new connection."""
has_uri = bool(args.conn_uri)
has_json = bool(args.conn_json)
has_type = bool(args.conn_type)
if not has_type and not (has_json or has_uri):
- raise SystemExit('Must supply either conn-uri or conn-json if not supplying conn-type')
+ raise SystemExit("Must supply either conn-uri or conn-json if not supplying conn-type")
if has_json and has_uri:
- raise SystemExit('Cannot supply both conn-uri and conn-json')
+ raise SystemExit("Cannot supply both conn-uri and conn-json")
if has_type and not (args.conn_type in _get_connection_types()):
- warnings.warn(f'The type provided to --conn-type is invalid: {args.conn_type}')
+ warnings.warn(f"The type provided to --conn-type is invalid: {args.conn_type}")
warnings.warn(
- f'Supported --conn-types are:{_get_connection_types()}.'
- 'Hence overriding the conn-type with generic'
+ f"Supported --conn-types are: {_get_connection_types()}. "
+ "Hence overriding the conn-type with generic"
)
- args.conn_type = 'generic'
+ args.conn_type = "generic"
if has_uri or has_json:
invalid_args = []
if has_uri and not _valid_uri(args.conn_uri):
- raise SystemExit(f'The URI provided to --conn-uri is invalid: {args.conn_uri}')
+ raise SystemExit(f"The URI provided to --conn-uri is invalid: {args.conn_uri}")
for arg in alternative_conn_specs:
if getattr(args, arg) is not None:
@@ -245,7 +243,7 @@ def connections_add(args):
elif args.conn_json:
new_conn = Connection.from_json(conn_id=args.conn_id, value=args.conn_json)
if not new_conn.conn_type:
- raise SystemExit('conn-json is invalid; must supply conn-type')
+ raise SystemExit("conn-json is invalid; must supply conn-type")
else:
new_conn = Connection(
conn_id=args.conn_id,
@@ -263,38 +261,37 @@ def connections_add(args):
with create_session() as session:
if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first():
session.add(new_conn)
- msg = 'Successfully added `conn_id`={conn_id} : {uri}'
+ msg = "Successfully added `conn_id`={conn_id} : {uri}"
msg = msg.format(
conn_id=new_conn.conn_id,
uri=args.conn_uri
- or urlunparse(
+ or urlunsplit(
(
new_conn.conn_type,
f"{new_conn.login or ''}:{'******' if new_conn.password else ''}"
f"@{new_conn.host or ''}:{new_conn.port or ''}",
- new_conn.schema or '',
- '',
- '',
- '',
+ new_conn.schema or "",
+ "",
+ "",
)
),
)
print(msg)
else:
- msg = f'A connection with `conn_id`={new_conn.conn_id} already exists.'
+ msg = f"A connection with `conn_id`={new_conn.conn_id} already exists."
raise SystemExit(msg)
@cli_utils.action_cli
def connections_delete(args):
- """Deletes connection from DB"""
+ """Deletes connection from DB."""
with create_session() as session:
try:
to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one()
except exc.NoResultFound:
- raise SystemExit(f'Did not find a connection with `conn_id`={args.conn_id}')
+ raise SystemExit(f"Did not find a connection with `conn_id`={args.conn_id}")
except exc.MultipleResultsFound:
- raise SystemExit(f'Found more than one connection with `conn_id`={args.conn_id}')
+ raise SystemExit(f"Found more than one connection with `conn_id`={args.conn_id}")
else:
session.delete(to_delete)
print(f"Successfully deleted connection with `conn_id`={to_delete.conn_id}")
@@ -302,7 +299,7 @@ def connections_delete(args):
@cli_utils.action_cli(check_db=False)
def connections_import(args):
- """Imports connections from a file"""
+ """Imports connections from a file."""
if os.path.exists(args.file):
_import_helper(args.file)
else:
@@ -315,9 +312,9 @@ def _import_helper(file_path):
with create_session() as session:
for conn_id, conn in connections_dict.items():
if session.query(Connection).filter(Connection.conn_id == conn_id).first():
- print(f'Could not import connection {conn_id}: connection already exists.')
+ print(f"Could not import connection {conn_id}: connection already exists.")
continue
session.add(conn)
session.commit()
- print(f'Imported connection {conn_id}')
+ print(f"Imported connection {conn_id}")
diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py
index a06c5a706907a..0465c474ed916 100644
--- a/airflow/cli/commands/dag_command.py
+++ b/airflow/cli/commands/dag_command.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Dag sub-commands."""
+from __future__ import annotations
-"""Dag sub-commands"""
import ast
import errno
import json
@@ -23,31 +24,32 @@
import signal
import subprocess
import sys
-from typing import Optional
from graphviz.dot import Dot
+from sqlalchemy.orm import Session
from sqlalchemy.sql.functions import func
from airflow import settings
from airflow.api.client import get_current_api_client
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, BackfillUnfinished
-from airflow.executors.debug_executor import DebugExecutor
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.jobs.base_job import BaseJob
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
-from airflow.utils import cli as cli_utils
-from airflow.utils.cli import get_dag, process_subdir, sigint_handler, suppress_logs_and_warning
+from airflow.utils import cli as cli_utils, timezone
+from airflow.utils.cli import get_dag, get_dags, process_subdir, sigint_handler, suppress_logs_and_warning
from airflow.utils.dot_renderer import render_dag, render_dag_dependencies
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState
+log = logging.getLogger(__name__)
+
@cli_utils.action_cli
def dag_backfill(args, dag=None):
- """Creates backfill job or dry run for a DAG"""
+ """Creates backfill job or dry run for a DAG or list of DAGs using regex."""
logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
signal.signal(signal.SIGTERM, sigint_handler)
@@ -55,8 +57,8 @@ def dag_backfill(args, dag=None):
import warnings
warnings.warn(
- '--ignore-first-depends-on-past is deprecated as the value is always set to True',
- category=PendingDeprecationWarning,
+ "--ignore-first-depends-on-past is deprecated as the value is always set to True",
+ category=RemovedInAirflow3Warning,
)
if args.ignore_first_depends_on_past is False:
@@ -65,69 +67,79 @@ def dag_backfill(args, dag=None):
if not args.start_date and not args.end_date:
raise AirflowException("Provide a start_date and/or end_date")
- dag = dag or get_dag(args.subdir, args.dag_id)
+ if not dag:
+ dags = get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_as_regex)
+ else:
+ dags = dag if isinstance(dag, list) else [dag]
+
+ dags.sort(key=lambda d: d.dag_id)
# If only one date is passed, using same as start and end
args.end_date = args.end_date or args.start_date
args.start_date = args.start_date or args.end_date
- if args.task_regex:
- dag = dag.partial_subset(
- task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies
- )
- if not dag.task_dict:
- raise AirflowException(
- f"There are no tasks that match '{args.task_regex}' regex. Nothing to run, exiting..."
- )
-
run_conf = None
if args.conf:
run_conf = json.loads(args.conf)
- if args.dry_run:
- print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
- dr = DagRun(dag.dag_id, execution_date=args.start_date)
- for task in dag.tasks:
- print(f"Task {task.task_id}")
- ti = TaskInstance(task, run_id=None)
- ti.dag_run = dr
- ti.dry_run()
- else:
- if args.reset_dagruns:
- DAG.clear_dags(
- [dag],
- start_date=args.start_date,
- end_date=args.end_date,
- confirm_prompt=not args.yes,
- include_subdags=True,
- dag_run_state=DagRunState.QUEUED,
- )
-
- try:
- dag.run(
- start_date=args.start_date,
- end_date=args.end_date,
- mark_success=args.mark_success,
- local=args.local,
- donot_pickle=(args.donot_pickle or conf.getboolean('core', 'donot_pickle')),
- ignore_first_depends_on_past=args.ignore_first_depends_on_past,
- ignore_task_deps=args.ignore_dependencies,
- pool=args.pool,
- delay_on_limit_secs=args.delay_on_limit,
- verbose=args.verbose,
- conf=run_conf,
- rerun_failed_tasks=args.rerun_failed_tasks,
- run_backwards=args.run_backwards,
- continue_on_failures=args.continue_on_failures,
+ for dag in dags:
+ if args.task_regex:
+ dag = dag.partial_subset(
+ task_ids_or_regex=args.task_regex, include_upstream=not args.ignore_dependencies
)
- except ValueError as vr:
- print(str(vr))
- sys.exit(1)
+ if not dag.task_dict:
+ raise AirflowException(
+ f"There are no tasks that match '{args.task_regex}' regex. Nothing to run, exiting..."
+ )
+
+ if args.dry_run:
+ print(f"Dry run of DAG {dag.dag_id} on {args.start_date}")
+ dr = DagRun(dag.dag_id, execution_date=args.start_date)
+ for task in dag.tasks:
+ print(f"Task {task.task_id} located in DAG {dag.dag_id}")
+ ti = TaskInstance(task, run_id=None)
+ ti.dag_run = dr
+ ti.dry_run()
+ else:
+ if args.reset_dagruns:
+ DAG.clear_dags(
+ [dag],
+ start_date=args.start_date,
+ end_date=args.end_date,
+ confirm_prompt=not args.yes,
+ include_subdags=True,
+ dag_run_state=DagRunState.QUEUED,
+ )
+
+ try:
+ dag.run(
+ start_date=args.start_date,
+ end_date=args.end_date,
+ mark_success=args.mark_success,
+ local=args.local,
+ donot_pickle=(args.donot_pickle or conf.getboolean("core", "donot_pickle")),
+ ignore_first_depends_on_past=args.ignore_first_depends_on_past,
+ ignore_task_deps=args.ignore_dependencies,
+ pool=args.pool,
+ delay_on_limit_secs=args.delay_on_limit,
+ verbose=args.verbose,
+ conf=run_conf,
+ rerun_failed_tasks=args.rerun_failed_tasks,
+ run_backwards=args.run_backwards,
+ continue_on_failures=args.continue_on_failures,
+ disable_retry=args.disable_retry,
+ )
+ except ValueError as vr:
+ print(str(vr))
+ sys.exit(1)
+
+ if len(dags) > 1:
+ log.info("All of the backfills are done.")
@cli_utils.action_cli
def dag_trigger(args):
- """Creates a dag run for the specified dag"""
+ """Creates a dag run for the specified dag."""
api_client = get_current_api_client()
try:
message = api_client.trigger_dag(
@@ -140,7 +152,7 @@ def dag_trigger(args):
@cli_utils.action_cli
def dag_delete(args):
- """Deletes all DB records related to the specified dag"""
+ """Deletes all DB records related to the specified dag."""
api_client = get_current_api_client()
if (
args.yes
@@ -158,18 +170,18 @@ def dag_delete(args):
@cli_utils.action_cli
def dag_pause(args):
- """Pauses a DAG"""
+ """Pauses a DAG."""
set_is_paused(True, args)
@cli_utils.action_cli
def dag_unpause(args):
- """Unpauses a DAG"""
+ """Unpauses a DAG."""
set_is_paused(False, args)
def set_is_paused(is_paused, args):
- """Sets is_paused for DAG by a given dag_id"""
+ """Sets is_paused for DAG by a given dag_id."""
dag = DagModel.get_dagmodel(args.dag_id)
if not dag:
@@ -181,7 +193,7 @@ def set_is_paused(is_paused, args):
def dag_dependencies_show(args):
- """Displays DAG dependencies, save to file or show as imgcat image"""
+ """Displays DAG dependencies, save to file or show as imgcat image."""
dot = render_dag_dependencies(SerializedDagModel.get_dag_dependencies())
filename = args.save
imgcat = args.imgcat
@@ -200,7 +212,7 @@ def dag_dependencies_show(args):
def dag_show(args):
- """Displays DAG or saves it's graphic representation to the file"""
+ """Displays DAG or saves its graphic representation to the file."""
dag = get_dag(args.subdir, args.dag_id)
dot = render_dag(dag)
filename = args.save
@@ -220,25 +232,23 @@ def dag_show(args):
def _display_dot_via_imgcat(dot: Dot):
- data = dot.pipe(format='png')
+ data = dot.pipe(format="png")
try:
with subprocess.Popen("imgcat", stdout=subprocess.PIPE, stdin=subprocess.PIPE) as proc:
out, err = proc.communicate(data)
if out:
- print(out.decode('utf-8'))
+ print(out.decode("utf-8"))
if err:
- print(err.decode('utf-8'))
+ print(err.decode("utf-8"))
except OSError as e:
if e.errno == errno.ENOENT:
- raise SystemExit(
- "Failed to execute. Make sure the imgcat executables are on your systems \'PATH\'"
- )
+ raise SystemExit("Failed to execute. Make sure the imgcat executables are on your systems 'PATH'")
else:
raise
def _save_dot_to_file(dot: Dot, filename: str):
- filename_without_ext, _, ext = filename.rpartition('.')
+ filename_without_ext, _, ext = filename.rpartition(".")
dot.render(filename=filename_without_ext, format=ext, cleanup=True)
print(f"File {filename} saved")
@@ -253,16 +263,15 @@ def dag_state(args, session=NEW_SESSION):
>>> airflow dags state a_dag_with_conf_passed 2015-01-01T00:00:00.000000
failed, {"name": "bob", "age": "42"}
"""
-
dag = DagModel.get_dagmodel(args.dag_id, session=session)
if not dag:
raise SystemExit(f"DAG: {args.dag_id} does not exist in 'dag' table")
dr = session.query(DagRun).filter_by(dag_id=args.dag_id, execution_date=args.execution_date).one_or_none()
out = dr.state if dr else None
- conf_out = ''
+ conf_out = ""
if out and dr.conf:
- conf_out = ', ' + json.dumps(dr.conf)
+ conf_out = ", " + json.dumps(dr.conf)
print(str(out) + conf_out)
@@ -284,7 +293,7 @@ def dag_next_execution(args):
.filter(DagRun.dag_id == dag.dag_id)
.subquery()
)
- max_date_run: Optional[DagRun] = (
+ max_date_run: DagRun | None = (
session.query(DagRun)
.filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == max_date_subq.c.max_date)
.one_or_none()
@@ -314,7 +323,7 @@ def dag_next_execution(args):
@cli_utils.action_cli
@suppress_logs_and_warning
def dag_list_dags(args):
- """Displays dags with or without stats at the command line"""
+ """Displays dags with or without stats at the command line."""
dagbag = DagBag(process_subdir(args.subdir))
if dagbag.import_errors:
from rich import print as rich_print
@@ -339,7 +348,7 @@ def dag_list_dags(args):
@cli_utils.action_cli
@suppress_logs_and_warning
def dag_list_import_errors(args):
- """Displays dags with import errors on the command line"""
+ """Displays dags with import errors on the command line."""
dagbag = DagBag(process_subdir(args.subdir))
data = []
for filename, errors in dagbag.import_errors.items():
@@ -353,7 +362,7 @@ def dag_list_import_errors(args):
@cli_utils.action_cli
@suppress_logs_and_warning
def dag_report(args):
- """Displays dagbag stats at the command line"""
+ """Displays dagbag stats at the command line."""
dagbag = DagBag(process_subdir(args.subdir))
AirflowConsole().print_as(
data=dagbag.dagbag_stats,
@@ -372,7 +381,7 @@ def dag_report(args):
@suppress_logs_and_warning
@provide_session
def dag_list_jobs(args, dag=None, session=NEW_SESSION):
- """Lists latest n jobs"""
+ """Lists latest n jobs."""
queries = []
if dag:
args.dag_id = dag.dag_id
@@ -386,7 +395,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION):
if args.state:
queries.append(BaseJob.state == args.state)
- fields = ['dag_id', 'state', 'job_type', 'start_date', 'end_date']
+ fields = ["dag_id", "state", "job_type", "start_date", "end_date"]
all_jobs = (
session.query(BaseJob).filter(*queries).order_by(BaseJob.start_date.desc()).limit(args.limit).all()
)
@@ -402,7 +411,7 @@ def dag_list_jobs(args, dag=None, session=NEW_SESSION):
@suppress_logs_and_warning
@provide_session
def dag_list_dag_runs(args, dag=None, session=NEW_SESSION):
- """Lists dag runs for a given DAG"""
+ """Lists dag runs for a given DAG."""
if dag:
args.dag_id = dag.dag_id
else:
@@ -430,31 +439,25 @@ def dag_list_dag_runs(args, dag=None, session=NEW_SESSION):
"run_id": dr.run_id,
"state": dr.state,
"execution_date": dr.execution_date.isoformat(),
- "start_date": dr.start_date.isoformat() if dr.start_date else '',
- "end_date": dr.end_date.isoformat() if dr.end_date else '',
+ "start_date": dr.start_date.isoformat() if dr.start_date else "",
+ "end_date": dr.end_date.isoformat() if dr.end_date else "",
},
)
@provide_session
@cli_utils.action_cli
-def dag_test(args, session=None):
- """Execute one single DagRun for a given DAG and execution date, using the DebugExecutor."""
- dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
- dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=False)
- try:
- dag.run(
- executor=DebugExecutor(),
- start_date=args.execution_date,
- end_date=args.execution_date,
- # Always run the DAG at least once even if no logical runs are
- # available. This does not make a lot of sense, but Airflow has
- # been doing this prior to 2.2 so we keep compatibility.
- run_at_least_once=True,
- )
- except BackfillUnfinished as e:
- print(str(e))
-
+def dag_test(args, dag=None, session=None):
+ """Execute one single DagRun for a given DAG and execution date."""
+ run_conf = None
+ if args.conf:
+ try:
+ run_conf = json.loads(args.conf)
+ except ValueError as e:
+ raise SystemExit(f"Configuration {args.conf!r} is not valid JSON. Error: {e}")
+ execution_date = args.execution_date or timezone.utcnow()
+ dag = dag or get_dag(subdir=args.subdir, dag_id=args.dag_id)
+ dag.test(execution_date=execution_date, run_conf=run_conf, session=session)
show_dagrun = args.show_dagrun
imgcat = args.imgcat_dagrun
filename = args.save_dagrun
@@ -463,7 +466,7 @@ def dag_test(args, session=None):
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == args.dag_id,
- TaskInstance.execution_date == args.execution_date,
+ TaskInstance.execution_date == execution_date,
)
.all()
)
@@ -480,10 +483,10 @@ def dag_test(args, session=None):
@provide_session
@cli_utils.action_cli
-def dag_reserialize(args, session=None):
+def dag_reserialize(args, session: Session = NEW_SESSION):
+ """Serialize a DAG instance."""
session.query(SerializedDagModel).delete(synchronize_session=False)
if not args.clear_only:
- dagbag = DagBag()
- dagbag.collect_dags(only_if_updated=False, safe_mode=False)
- dagbag.sync_to_db()
+ dagbag = DagBag(process_subdir(args.subdir))
+ dagbag.sync_to_db(session=session)
diff --git a/airflow/cli/commands/dag_processor_command.py b/airflow/cli/commands/dag_processor_command.py
index d57e26510c974..f8ce65663bd40 100644
--- a/airflow/cli/commands/dag_processor_command.py
+++ b/airflow/cli/commands/dag_processor_command.py
@@ -14,14 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""DagProcessor command."""
+from __future__ import annotations
-"""DagProcessor command"""
import logging
from datetime import timedelta
import daemon
from daemon.pidfile import TimeoutPIDLockFile
+from airflow import settings
from airflow.configuration import conf
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.utils import cli as cli_utils
@@ -32,7 +34,7 @@
def _create_dag_processor_manager(args) -> DagFileProcessorManager:
"""Creates DagFileProcessorManager instance."""
- processor_timeout_seconds: int = conf.getint('core', 'dag_file_processor_timeout')
+ processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)
return DagFileProcessorManager(
dag_directory=args.subdir,
@@ -45,13 +47,13 @@ def _create_dag_processor_manager(args) -> DagFileProcessorManager:
@cli_utils.action_cli
def dag_processor(args):
- """Starts Airflow Dag Processor Job"""
+ """Starts Airflow Dag Processor Job."""
if not conf.getboolean("scheduler", "standalone_dag_processor"):
- raise SystemExit('The option [scheduler/standalone_dag_processor] must be True.')
+ raise SystemExit("The option [scheduler/standalone_dag_processor] must be True.")
- sql_conn: str = conf.get('database', 'sql_alchemy_conn').lower()
- if sql_conn.startswith('sqlite'):
- raise SystemExit('Standalone DagProcessor is not supported when using sqlite.')
+ sql_conn: str = conf.get("database", "sql_alchemy_conn").lower()
+ if sql_conn.startswith("sqlite"):
+ raise SystemExit("Standalone DagProcessor is not supported when using sqlite.")
manager = _create_dag_processor_manager(args)
@@ -60,12 +62,16 @@ def dag_processor(args):
"dag-processor", args.pid, args.stdout, args.stderr, args.log_file
)
handle = setup_logging(log_file)
- with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle:
+ with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
+ stdout_handle.truncate(0)
+ stderr_handle.truncate(0)
+
ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pid, -1),
files_preserve=[handle],
stdout=stdout_handle,
stderr=stderr_handle,
+ umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
try:
diff --git a/airflow/cli/commands/db_command.py b/airflow/cli/commands/db_command.py
index c9201ad59ba80..4064285db8b45 100644
--- a/airflow/cli/commands/db_command.py
+++ b/airflow/cli/commands/db_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Database sub-commands"""
+"""Database sub-commands."""
+from __future__ import annotations
+
import os
import textwrap
from tempfile import NamedTemporaryFile
@@ -30,14 +32,14 @@
def initdb(args):
- """Initializes the metadata database"""
+ """Initializes the metadata database."""
print("DB: " + repr(settings.engine.url))
db.initdb()
print("Initialization done")
def resetdb(args):
- """Resets the metadata database"""
+ """Resets the metadata database."""
print("DB: " + repr(settings.engine.url))
if not (args.yes or input("This will drop existing tables if they exist. Proceed? (y/n)").upper() == "Y"):
raise SystemExit("Cancelled")
@@ -46,7 +48,7 @@ def resetdb(args):
@cli_utils.action_cli(check_db=False)
def upgradedb(args):
- """Upgrades the metadata database"""
+ """Upgrades the metadata database."""
print("DB: " + repr(settings.engine.url))
if args.to_revision and args.to_version:
raise SystemExit("Cannot supply both `--to-revision` and `--to-version`.")
@@ -61,7 +63,7 @@ def upgradedb(args):
if args.from_revision:
from_revision = args.from_revision
elif args.from_version:
- if parse_version(args.from_version) < parse_version('2.0.0'):
+ if parse_version(args.from_version) < parse_version("2.0.0"):
raise SystemExit("--from-version must be greater than or equal to 2.0.0")
from_revision = REVISION_HEADS_MAP.get(args.from_version)
if not from_revision:
@@ -79,14 +81,19 @@ def upgradedb(args):
else:
print("Generating sql for upgrade -- upgrade commands will *not* be submitted.")
- db.upgradedb(to_revision=to_revision, from_revision=from_revision, show_sql_only=args.show_sql_only)
+ db.upgradedb(
+ to_revision=to_revision,
+ from_revision=from_revision,
+ show_sql_only=args.show_sql_only,
+ reserialize_dags=args.reserialize_dags,
+ )
if not args.show_sql_only:
print("Upgrades done")
@cli_utils.action_cli(check_db=False)
def downgrade(args):
- """Downgrades the metadata database"""
+ """Downgrades the metadata database."""
if args.to_revision and args.to_version:
raise SystemExit("Cannot supply both `--to-revision` and `--to-version`.")
if args.from_version and args.from_revision:
@@ -132,17 +139,17 @@ def downgrade(args):
def check_migrations(args):
- """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s"""
+ """Function to wait for all airflow migrations to complete. Used for launching airflow in k8s."""
db.check_migrations(timeout=args.migration_wait_timeout)
@cli_utils.action_cli(check_db=False)
def shell(args):
- """Run a shell that allows to access metadata database"""
+ """Run a shell that allows to access metadata database."""
url = settings.engine.url
print("DB: " + repr(url))
- if url.get_backend_name() == 'mysql':
+ if url.get_backend_name() == "mysql":
with NamedTemporaryFile(suffix="my.cnf") as f:
content = textwrap.dedent(
f"""
@@ -157,23 +164,23 @@ def shell(args):
f.write(content.encode())
f.flush()
execute_interactive(["mysql", f"--defaults-extra-file={f.name}"])
- elif url.get_backend_name() == 'sqlite':
+ elif url.get_backend_name() == "sqlite":
execute_interactive(["sqlite3", url.database])
- elif url.get_backend_name() == 'postgresql':
+ elif url.get_backend_name() == "postgresql":
env = os.environ.copy()
- env['PGHOST'] = url.host or ""
- env['PGPORT'] = str(url.port or "5432")
- env['PGUSER'] = url.username or ""
+ env["PGHOST"] = url.host or ""
+ env["PGPORT"] = str(url.port or "5432")
+ env["PGUSER"] = url.username or ""
# PostgreSQL does not allow the use of PGPASSFILE if the current user is root.
env["PGPASSWORD"] = url.password or ""
- env['PGDATABASE'] = url.database
+ env["PGDATABASE"] = url.database
execute_interactive(["psql"], env=env)
- elif url.get_backend_name() == 'mssql':
+ elif url.get_backend_name() == "mssql":
env = os.environ.copy()
- env['MSSQL_CLI_SERVER'] = url.host
- env['MSSQL_CLI_DATABASE'] = url.database
- env['MSSQL_CLI_USER'] = url.username
- env['MSSQL_CLI_PASSWORD'] = url.password
+ env["MSSQL_CLI_SERVER"] = url.host
+ env["MSSQL_CLI_DATABASE"] = url.database
+ env["MSSQL_CLI_USER"] = url.username
+ env["MSSQL_CLI_PASSWORD"] = url.password
execute_interactive(["mssql-cli"], env=env)
else:
raise AirflowException(f"Unknown driver: {url.drivername}")
@@ -191,11 +198,12 @@ def check(_):
@cli_utils.action_cli(check_db=False)
def cleanup_tables(args):
- """Purges old records in metadata database"""
+ """Purges old records in metadata database."""
run_cleanup(
table_names=args.tables,
dry_run=args.dry_run,
clean_before_timestamp=args.clean_before_timestamp,
verbose=args.verbose,
confirm=not args.yes,
+ skip_archive=args.skip_archive,
)
diff --git a/airflow/cli/commands/info_command.py b/airflow/cli/commands/info_command.py
index fc03615210af3..7261dfc484156 100644
--- a/airflow/cli/commands/info_command.py
+++ b/airflow/cli/commands/info_command.py
@@ -14,14 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Config sub-commands"""
+"""Config sub-commands."""
+from __future__ import annotations
+
import locale
import logging
import os
import platform
import subprocess
import sys
-from typing import List, Optional
+from enum import Enum
from urllib.parse import urlsplit, urlunsplit
import httpx
@@ -42,13 +44,13 @@ class Anonymizer(Protocol):
"""Anonymizer protocol."""
def process_path(self, value) -> str:
- """Remove pii from paths"""
+ """Remove pii from paths."""
def process_username(self, value) -> str:
- """Remove pii from username"""
+ """Remove pii from username."""
def process_url(self, value) -> str:
- """Remove pii from URL"""
+ """Remove pii from URL."""
class NullAnonymizer(Anonymizer):
@@ -123,17 +125,18 @@ def process_url(self, value) -> str:
return urlunsplit((url_parts.scheme, netloc, url_parts.path, url_parts.query, url_parts.fragment))
-class OperatingSystem:
- """Operating system"""
+class OperatingSystem(Enum):
+ """Operating system."""
WINDOWS = "Windows"
LINUX = "Linux"
MACOSX = "Mac OS"
CYGWIN = "Cygwin"
+ UNKNOWN = "Unknown"
@staticmethod
- def get_current() -> Optional[str]:
- """Get current operating system"""
+ def get_current() -> OperatingSystem:
+ """Get current operating system."""
if os.name == "nt":
return OperatingSystem.WINDOWS
elif "linux" in sys.platform:
@@ -142,24 +145,26 @@ def get_current() -> Optional[str]:
return OperatingSystem.MACOSX
elif "cygwin" in sys.platform:
return OperatingSystem.CYGWIN
- return None
+ return OperatingSystem.UNKNOWN
-class Architecture:
- """Compute architecture"""
+class Architecture(Enum):
+ """Compute architecture."""
X86_64 = "x86_64"
X86 = "x86"
PPC = "ppc"
ARM = "arm"
+ UNKNOWN = "unknown"
@staticmethod
- def get_current():
- """Get architecture"""
- return _MACHINE_TO_ARCHITECTURE.get(platform.machine().lower())
+ def get_current() -> Architecture:
+ """Get architecture."""
+ current_architecture = _MACHINE_TO_ARCHITECTURE.get(platform.machine().lower())
+ return current_architecture if current_architecture else Architecture.UNKNOWN
-_MACHINE_TO_ARCHITECTURE = {
+_MACHINE_TO_ARCHITECTURE: dict[str, Architecture] = {
"amd64": Architecture.X86_64,
"x86_64": Architecture.X86_64,
"i686-64": Architecture.X86_64,
@@ -175,17 +180,18 @@ def get_current():
"arm64": Architecture.ARM,
"armv7": Architecture.ARM,
"armv7l": Architecture.ARM,
+ "aarch64": Architecture.ARM,
}
class AirflowInfo:
- """Renders information about Airflow instance"""
+ """Renders information about Airflow instance."""
def __init__(self, anonymizer):
self.anonymizer = anonymizer
@staticmethod
- def _get_version(cmd: List[str], grep: Optional[bytes] = None):
+ def _get_version(cmd: list[str], grep: bytes | None = None):
"""Return tools version."""
try:
with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as proc:
@@ -209,10 +215,10 @@ def get_fullname(o):
if module is None or module == str.__class__.__module__:
return o.__class__.__name__ # Avoid reporting __builtin__
else:
- return module + '.' + o.__class__.__name__
+ return module + "." + o.__class__.__name__
try:
- handler_names = [get_fullname(handler) for handler in logging.getLogger('airflow.task').handlers]
+ handler_names = [get_fullname(handler) for handler in logging.getLogger("airflow.task").handlers]
return ", ".join(handler_names)
except Exception:
return "NOT AVAILABLE"
@@ -257,8 +263,8 @@ def _system_info(self):
python_version = sys.version.replace("\n", " ")
return [
- ("OS", operating_system or "NOT AVAILABLE"),
- ("architecture", arch or "NOT AVAILABLE"),
+ ("OS", operating_system.value),
+ ("architecture", arch.value),
("uname", str(uname)),
("locale", str(_locale)),
("python_version", python_version),
@@ -304,10 +310,10 @@ def _paths_info(self):
@property
def _providers_info(self):
- return [(p.data['package-name'], p.version) for p in ProvidersManager().providers.values()]
+ return [(p.data["package-name"], p.version) for p in ProvidersManager().providers.values()]
- def show(self, output: str, console: Optional[AirflowConsole] = None) -> None:
- """Shows information about Airflow instance"""
+ def show(self, output: str, console: AirflowConsole | None = None) -> None:
+ """Shows information about Airflow instance."""
all_info = {
"Apache Airflow": self._airflow_info,
"System info": self._system_info,
@@ -329,7 +335,7 @@ def show(self, output: str, console: Optional[AirflowConsole] = None) -> None:
)
def render_text(self, output: str) -> str:
- """Exports the info to string"""
+ """Exports the info to string."""
console = AirflowConsole(record=True)
with console.capture():
self.show(output=output, console=console)
@@ -337,7 +343,7 @@ def render_text(self, output: str) -> str:
class FileIoException(Exception):
- """Raises when error happens in FileIo.io integration"""
+ """Raises when error happens in FileIo.io integration."""
@tenacity.retry(
@@ -348,7 +354,7 @@ class FileIoException(Exception):
after=tenacity.after_log(log, logging.DEBUG),
)
def _upload_text_to_fileio(content):
- """Upload text file to File.io service and return lnk"""
+ """Upload text file to File.io service and return link."""
resp = httpx.post("https://file.io", content=content)
if resp.status_code not in [200, 201]:
print(resp.json())
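
Note on the `info_command.py` changes: `OperatingSystem` and `Architecture` now return an explicit `UNKNOWN` enum member instead of `None`, which is why `_system_info` can read `.value` unconditionally. A minimal sketch of that pattern follows. The names `Arch`, `_MACHINE_TO_ARCH`, and `detect` are illustrative only and are not the names used in the diff.

```python
from enum import Enum


class Arch(Enum):
    """Illustrative subset of architectures with an explicit fallback member."""

    X86_64 = "x86_64"
    ARM = "arm"
    UNKNOWN = "unknown"


# Hypothetical lookup table keyed by lower-cased machine strings.
_MACHINE_TO_ARCH = {
    "amd64": Arch.X86_64,
    "x86_64": Arch.X86_64,
    "arm64": Arch.ARM,
    "aarch64": Arch.ARM,
}


def detect(machine: str) -> Arch:
    """Map a machine string to an Arch member, never returning None."""
    return _MACHINE_TO_ARCH.get(machine.lower(), Arch.UNKNOWN)
```

Because the function always returns an enum member, a caller can write `detect(name).value` without an `or "NOT AVAILABLE"` guard.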
diff --git a/airflow/cli/commands/jobs_command.py b/airflow/cli/commands/jobs_command.py
index f6d8d55fb6b9a..c030b5ea9b31b 100644
--- a/airflow/cli/commands/jobs_command.py
+++ b/airflow/cli/commands/jobs_command.py
@@ -14,19 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-from typing import List
+from __future__ import annotations
from airflow.jobs.base_job import BaseJob
+from airflow.utils.net import get_hostname
from airflow.utils.session import provide_session
from airflow.utils.state import State
@provide_session
def check(args, session=None):
- """Checks if job(s) are still alive"""
+ """Checks if job(s) are still alive."""
if args.allow_multiple and not args.limit > 1:
raise SystemExit("To use option --allow-multiple, you must set the limit to a value greater than 1.")
+ if args.hostname and args.local:
+ raise SystemExit("You can't use --hostname and --local at the same time")
+
query = (
session.query(BaseJob)
.filter(BaseJob.state == State.RUNNING)
@@ -36,10 +39,12 @@ def check(args, session=None):
query = query.filter(BaseJob.job_type == args.job_type)
if args.hostname:
query = query.filter(BaseJob.hostname == args.hostname)
+ if args.local:
+ query = query.filter(BaseJob.hostname == get_hostname())
if args.limit > 0:
query = query.limit(args.limit)
- jobs: List[BaseJob] = query.all()
+ jobs: list[BaseJob] = query.all()
alive_jobs = [job for job in jobs if job.is_alive()]
count_alive_jobs = len(alive_jobs)
diff --git a/airflow/cli/commands/kerberos_command.py b/airflow/cli/commands/kerberos_command.py
index fea874349923b..4bbe3f6df919d 100644
--- a/airflow/cli/commands/kerberos_command.py
+++ b/airflow/cli/commands/kerberos_command.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Kerberos command."""
+from __future__ import annotations
-"""Kerberos command"""
import daemon
from daemon.pidfile import TimeoutPIDLockFile
@@ -27,18 +28,22 @@
@cli_utils.action_cli
def kerberos(args):
- """Start a kerberos ticket renewer"""
+ """Start a kerberos ticket renewer."""
print(settings.HEADER)
if args.daemon:
pid, stdout, stderr, _ = setup_locations(
"kerberos", args.pid, args.stdout, args.stderr, args.log_file
)
- with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle:
+ with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
+ stdout_handle.truncate(0)
+ stderr_handle.truncate(0)
+
ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pid, -1),
stdout=stdout_handle,
stderr=stderr_handle,
+ umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
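
Note on the daemon branches: the log files are now opened in append mode and truncated explicitly, and the daemon context receives `umask=int(settings.DAEMON_UMASK, 8)`. The sketch below shows what those two idioms do in isolation. It assumes the setting is an octal string such as `"0o077"` (the diff does not show its default), and `parse_umask`, `reset_and_write`, and `demo` are hypothetical helpers written for this illustration.

```python
import os
import tempfile


def parse_umask(text: str) -> int:
    """Parse an octal string such as "0o077" or "077" into an int."""
    return int(text, 8)


def reset_and_write(path: str, data: str) -> str:
    """Open in append mode, drop old contents, write, and return the file text."""
    with open(path, "a") as handle:
        handle.truncate(0)  # discard whatever was in the file before
        handle.write(data)
    with open(path) as handle:
        return handle.read()


def demo() -> str:
    """Write stale text to a scratch file, then reset it the way the diff does."""
    # Hypothetical scratch file standing in for a daemon's stdout log.
    fd, path = tempfile.mkstemp(suffix=".out")
    try:
        with os.fdopen(fd, "w") as handle:
            handle.write("stale output\n")
        return reset_and_write(path, "fresh output\n")
    finally:
        os.remove(path)
```

The end result matches opening with `"w"`: the previous contents are gone and only the new text remains.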
diff --git a/airflow/cli/commands/kubernetes_command.py b/airflow/cli/commands/kubernetes_command.py
index 7c26821780575..052bf339c07b0 100644
--- a/airflow/cli/commands/kubernetes_command.py
+++ b/airflow/cli/commands/kubernetes_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Kubernetes sub-commands"""
+"""Kubernetes sub-commands."""
+from __future__ import annotations
+
import os
import sys
from datetime import datetime, timedelta
@@ -35,7 +37,7 @@
@cli_utils.action_cli
def generate_pod_yaml(args):
- """Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor"""
+ """Generates yaml files for each task in the DAG. Used for testing output of KubernetesExecutor."""
execution_date = args.execution_date
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
yaml_output_path = args.output_path
@@ -70,7 +72,7 @@ def generate_pod_yaml(args):
@cli_utils.action_cli
def cleanup_pods(args):
- """Clean up k8s pods in evicted/failed/succeeded/pending states"""
+ """Clean up k8s pods in evicted/failed/succeeded/pending states."""
namespace = args.namespace
min_pending_minutes = args.min_pending_minutes
@@ -80,42 +82,42 @@ def cleanup_pods(args):
# https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/
# All Containers in the Pod have terminated in success, and will not be restarted.
- pod_succeeded = 'succeeded'
+ pod_succeeded = "succeeded"
# The Pod has been accepted by the Kubernetes cluster,
# but one or more of the containers has not been set up and made ready to run.
- pod_pending = 'pending'
+ pod_pending = "pending"
# All Containers in the Pod have terminated, and at least one Container has terminated in failure.
# That is, the Container either exited with non-zero status or was terminated by the system.
- pod_failed = 'failed'
+ pod_failed = "failed"
# https://kubernetes.io/docs/tasks/administer-cluster/out-of-resource/
- pod_reason_evicted = 'evicted'
+ pod_reason_evicted = "evicted"
# If pod is failed and restartPolicy is:
# * Always: Restart Container; Pod phase stays Running.
# * OnFailure: Restart Container; Pod phase stays Running.
# * Never: Pod phase becomes Failed.
- pod_restart_policy_never = 'never'
+ pod_restart_policy_never = "never"
- print('Loading Kubernetes configuration')
+ print("Loading Kubernetes configuration")
kube_client = get_kube_client()
- print(f'Listing pods in namespace {namespace}')
+ print(f"Listing pods in namespace {namespace}")
airflow_pod_labels = [
- 'dag_id',
- 'task_id',
- 'try_number',
- 'airflow_version',
+ "dag_id",
+ "task_id",
+ "try_number",
+ "airflow_version",
]
- list_kwargs = {"namespace": namespace, "limit": 500, "label_selector": ','.join(airflow_pod_labels)}
+ list_kwargs = {"namespace": namespace, "limit": 500, "label_selector": ",".join(airflow_pod_labels)}
while True:
pod_list = kube_client.list_namespaced_pod(**list_kwargs)
for pod in pod_list.items:
pod_name = pod.metadata.name
- print(f'Inspecting pod {pod_name}')
+ print(f"Inspecting pod {pod_name}")
pod_phase = pod.status.phase.lower()
- pod_reason = pod.status.reason.lower() if pod.status.reason else ''
+ pod_reason = pod.status.reason.lower() if pod.status.reason else ""
pod_restart_policy = pod.spec.restart_policy.lower()
current_time = datetime.now(pod.metadata.creation_timestamp.tzinfo)
@@ -138,7 +140,7 @@ def cleanup_pods(args):
except ApiException as e:
print(f"Can't remove POD: {e}", file=sys.stderr)
continue
- print(f'No action taken on pod {pod_name}')
+ print(f"No action taken on pod {pod_name}")
continue_token = pod_list.metadata._continue
if not continue_token:
break
@@ -146,7 +148,7 @@ def cleanup_pods(args):
def _delete_pod(name, namespace):
- """Helper Function for cleanup_pods"""
+ """Helper Function for cleanup_pods."""
core_v1 = client.CoreV1Api()
delete_options = client.V1DeleteOptions()
print(f'Deleting POD "{name}" from "{namespace}" namespace')
diff --git a/airflow/cli/commands/legacy_commands.py b/airflow/cli/commands/legacy_commands.py
index 94f9b690327b2..910c6e442703f 100644
--- a/airflow/cli/commands/legacy_commands.py
+++ b/airflow/cli/commands/legacy_commands.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from argparse import ArgumentError
@@ -49,7 +50,7 @@
def check_legacy_command(action, value):
- """Checks command value and raise error if value is in removed command"""
+ """Checks command value and raise error if value is in removed command."""
new_command = COMMAND_MAP.get(value)
if new_command is not None:
msg = f"`airflow {value}` command, has been removed, please use `airflow {new_command}`"
diff --git a/airflow/cli/commands/plugins_command.py b/airflow/cli/commands/plugins_command.py
index 2d59e901a187a..50ee583099110 100644
--- a/airflow/cli/commands/plugins_command.py
+++ b/airflow/cli/commands/plugins_command.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import inspect
-from typing import Any, Dict, List, Union
+from typing import Any
from airflow import plugins_manager
from airflow.cli.simple_table import AirflowConsole
@@ -31,15 +33,15 @@ def _get_name(class_like_object) -> str:
return class_like_object.__class__.__name__
-def _join_plugins_names(value: Union[List[Any], Any]) -> str:
+def _join_plugins_names(value: list[Any] | Any) -> str:
value = value if isinstance(value, list) else [value]
return ",".join(_get_name(v) for v in value)
@suppress_logs_and_warning
def dump_plugins(args):
- """Dump plugins information"""
- plugins_info: List[Dict[str, str]] = get_plugin_info()
+ """Dump plugins information."""
+ plugins_info: list[dict[str, str]] = get_plugin_info()
if not plugins_manager.plugins:
print("No plugins loaded")
return
diff --git a/airflow/cli/commands/pool_command.py b/airflow/cli/commands/pool_command.py
index e435c2a4833bc..aa56ba8fea7de 100644
--- a/airflow/cli/commands/pool_command.py
+++ b/airflow/cli/commands/pool_command.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Pools sub-commands"""
+"""Pools sub-commands."""
+from __future__ import annotations
+
import json
import os
from json import JSONDecodeError
@@ -41,7 +43,7 @@ def _show_pools(pools, output):
@suppress_logs_and_warning
def pool_list(args):
- """Displays info of all the pools"""
+ """Displays info of all the pools."""
api_client = get_current_api_client()
pools = api_client.get_pools()
_show_pools(pools=pools, output=args.output)
@@ -49,7 +51,7 @@ def pool_list(args):
@suppress_logs_and_warning
def pool_get(args):
- """Displays pool info by a given name"""
+ """Displays pool info by a given name."""
api_client = get_current_api_client()
try:
pools = [api_client.get_pool(name=args.pool)]
@@ -61,7 +63,7 @@ def pool_get(args):
@cli_utils.action_cli
@suppress_logs_and_warning
def pool_set(args):
- """Creates new pool with a given name and slots"""
+ """Creates new pool with a given name and slots."""
api_client = get_current_api_client()
api_client.create_pool(name=args.pool, slots=args.slots, description=args.description)
print(f"Pool {args.pool} created")
@@ -70,7 +72,7 @@ def pool_set(args):
@cli_utils.action_cli
@suppress_logs_and_warning
def pool_delete(args):
- """Deletes pool by a given name"""
+ """Deletes pool by a given name."""
api_client = get_current_api_client()
try:
api_client.delete_pool(name=args.pool)
@@ -82,7 +84,7 @@ def pool_delete(args):
@cli_utils.action_cli
@suppress_logs_and_warning
def pool_import(args):
- """Imports pools from the file"""
+ """Imports pools from the file."""
if not os.path.exists(args.file):
raise SystemExit(f"Missing pools file {args.file}")
pools, failed = pool_import_helper(args.file)
@@ -92,13 +94,13 @@ def pool_import(args):
def pool_export(args):
- """Exports all of the pools to the file"""
+ """Exports all the pools to the file."""
pools = pool_export_helper(args.file)
print(f"Exported {len(pools)} pools to {args.file}")
def pool_import_helper(filepath):
- """Helps import pools from the json file"""
+ """Helps import pools from the json file."""
api_client = get_current_api_client()
with open(filepath) as poolfile:
@@ -118,12 +120,12 @@ def pool_import_helper(filepath):
def pool_export_helper(filepath):
- """Helps export all of the pools to the json file"""
+ """Helps export all the pools to the json file."""
api_client = get_current_api_client()
pool_dict = {}
pools = api_client.get_pools()
for pool in pools:
pool_dict[pool[0]] = {"slots": pool[1], "description": pool[2]}
- with open(filepath, 'w') as poolfile:
+ with open(filepath, "w") as poolfile:
poolfile.write(json.dumps(pool_dict, sort_keys=True, indent=4))
return pools
diff --git a/airflow/cli/commands/provider_command.py b/airflow/cli/commands/provider_command.py
index b89f77fb5d2a4..cae605679490f 100644
--- a/airflow/cli/commands/provider_command.py
+++ b/airflow/cli/commands/provider_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Providers sub-commands"""
+"""Providers sub-commands."""
+from __future__ import annotations
+
import re
from airflow.cli.simple_table import AirflowConsole
@@ -50,7 +52,7 @@ def provider_get(args):
@suppress_logs_and_warning
def providers_list(args):
- """Lists all providers at the command line"""
+ """Lists all providers at the command line."""
AirflowConsole().print_as(
data=list(ProvidersManager().providers.values()),
output=args.output,
@@ -64,7 +66,7 @@ def providers_list(args):
@suppress_logs_and_warning
def hooks_list(args):
- """Lists all hooks at the command line"""
+ """Lists all hooks at the command line."""
AirflowConsole().print_as(
data=list(ProvidersManager().hooks.items()),
output=args.output,
@@ -72,30 +74,30 @@ def hooks_list(args):
"connection_type": x[0],
"class": x[1].hook_class_name if x[1] else ERROR_IMPORTING_HOOK,
"conn_id_attribute_name": x[1].connection_id_attribute_name if x[1] else ERROR_IMPORTING_HOOK,
- 'package_name': x[1].package_name if x[1] else ERROR_IMPORTING_HOOK,
- 'hook_name': x[1].hook_name if x[1] else ERROR_IMPORTING_HOOK,
+ "package_name": x[1].package_name if x[1] else ERROR_IMPORTING_HOOK,
+ "hook_name": x[1].hook_name if x[1] else ERROR_IMPORTING_HOOK,
},
)
@suppress_logs_and_warning
def connection_form_widget_list(args):
- """Lists all custom connection form fields at the command line"""
+ """Lists all custom connection form fields at the command line."""
AirflowConsole().print_as(
- data=list(ProvidersManager().connection_form_widgets.items()),
+ data=list(sorted(ProvidersManager().connection_form_widgets.items())),
output=args.output,
mapper=lambda x: {
"connection_parameter_name": x[0],
"class": x[1].hook_class_name,
- 'package_name': x[1].package_name,
- 'field_type': x[1].field.field_class.__name__,
+ "package_name": x[1].package_name,
+ "field_type": x[1].field.field_class.__name__,
},
)
@suppress_logs_and_warning
def connection_field_behaviours(args):
- """Lists field behaviours"""
+ """Lists field behaviours."""
AirflowConsole().print_as(
data=list(ProvidersManager().field_behaviours.keys()),
output=args.output,
@@ -107,7 +109,7 @@ def connection_field_behaviours(args):
@suppress_logs_and_warning
def extra_links_list(args):
- """Lists all extra links at the command line"""
+ """Lists all extra links at the command line."""
AirflowConsole().print_as(
data=ProvidersManager().extra_links_class_names,
output=args.output,
@@ -119,7 +121,7 @@ def extra_links_list(args):
@suppress_logs_and_warning
def logging_list(args):
- """Lists all log task handlers at the command line"""
+ """Lists all log task handlers at the command line."""
AirflowConsole().print_as(
data=list(ProvidersManager().logging_class_names),
output=args.output,
@@ -131,7 +133,7 @@ def logging_list(args):
@suppress_logs_and_warning
def secrets_backends_list(args):
- """Lists all secrets backends at the command line"""
+ """Lists all secrets backends at the command line."""
AirflowConsole().print_as(
data=list(ProvidersManager().secrets_backend_class_names),
output=args.output,
@@ -143,7 +145,7 @@ def secrets_backends_list(args):
@suppress_logs_and_warning
def auth_backend_list(args):
- """Lists all API auth backend modules at the command line"""
+ """Lists all API auth backend modules at the command line."""
AirflowConsole().print_as(
data=list(ProvidersManager().auth_backend_module_names),
output=args.output,
diff --git a/airflow/cli/commands/role_command.py b/airflow/cli/commands/role_command.py
index 241a0409c4679..571db3eefe11e 100644
--- a/airflow/cli/commands/role_command.py
+++ b/airflow/cli/commands/role_command.py
@@ -15,8 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-"""Roles sub-commands"""
+"""Roles sub-commands."""
+from __future__ import annotations
+
+import collections
+import itertools
import json
import os
@@ -24,33 +27,128 @@
from airflow.utils import cli as cli_utils
from airflow.utils.cli import suppress_logs_and_warning
from airflow.www.app import cached_app
-from airflow.www.security import EXISTING_ROLES
+from airflow.www.fab_security.sqla.models import Action, Permission, Resource, Role
+from airflow.www.security import EXISTING_ROLES, AirflowSecurityManager
@suppress_logs_and_warning
def roles_list(args):
- """Lists all existing roles"""
+ """Lists all existing roles."""
appbuilder = cached_app().appbuilder
roles = appbuilder.sm.get_all_roles()
+
+ if not args.permission:
+ AirflowConsole().print_as(
+ data=sorted(r.name for r in roles), output=args.output, mapper=lambda x: {"name": x}
+ )
+ return
+
+ permission_map: dict[tuple[str, str], list[str]] = collections.defaultdict(list)
+ for role in roles:
+ for permission in role.permissions:
+ permission_map[(role.name, permission.resource.name)].append(permission.action.name)
+
AirflowConsole().print_as(
- data=sorted(r.name for r in roles), output=args.output, mapper=lambda x: {"name": x}
+ data=sorted(permission_map),
+ output=args.output,
+ mapper=lambda x: {"name": x[0], "resource": x[1], "action": ",".join(sorted(permission_map[x]))},
)
@cli_utils.action_cli
@suppress_logs_and_warning
def roles_create(args):
- """Creates new empty role in DB"""
+ """Creates new empty role in DB."""
appbuilder = cached_app().appbuilder
for role_name in args.role:
appbuilder.sm.add_role(role_name)
print(f"Added {len(args.role)} role(s)")
+@cli_utils.action_cli
+@suppress_logs_and_warning
+def roles_delete(args):
+ """Deletes role in DB."""
+ appbuilder = cached_app().appbuilder
+
+ for role_name in args.role:
+ role = appbuilder.sm.find_role(role_name)
+ if not role:
+ print(f"Role named '{role_name}' does not exist")
+ exit(1)
+
+ for role_name in args.role:
+ appbuilder.sm.delete_role(role_name)
+ print(f"Deleted {len(args.role)} role(s)")
+
+
+def __roles_add_or_remove_permissions(args):
+ asm: AirflowSecurityManager = cached_app().appbuilder.sm
+ is_add: bool = args.subcommand.startswith("add")
+
+ role_map = {}
+ perm_map: dict[tuple[str, str], set[str]] = collections.defaultdict(set)
+ for name in args.role:
+ role: Role | None = asm.find_role(name)
+ if not role:
+ print(f"Role named '{name}' does not exist")
+ exit(1)
+
+ role_map[name] = role
+ for permission in role.permissions:
+ perm_map[(name, permission.resource.name)].add(permission.action.name)
+
+ for name in args.resource:
+ resource: Resource | None = asm.get_resource(name)
+ if not resource:
+ print(f"Resource named '{name}' does not exist")
+ exit(1)
+
+ for name in args.action or []:
+ action: Action | None = asm.get_action(name)
+ if not action:
+ print(f"Action named '{name}' does not exist")
+ exit(1)
+
+ permission_count = 0
+ for (role_name, resource_name, action_name) in list(
+ itertools.product(args.role, args.resource, args.action or [None])
+ ):
+ res_key = (role_name, resource_name)
+ if is_add and action_name not in perm_map[res_key]:
+ perm: Permission | None = asm.create_permission(action_name, resource_name)
+ asm.add_permission_to_role(role_map[role_name], perm)
+ print(f"Added {perm} to role {role_name}")
+ permission_count += 1
+ elif not is_add and res_key in perm_map:
+ for _action_name in perm_map[res_key] if action_name is None else [action_name]:
+ perm: Permission | None = asm.get_permission(_action_name, resource_name)
+ asm.remove_permission_from_role(role_map[role_name], perm)
+ print(f"Deleted {perm} from role {role_name}")
+ permission_count += 1
+
+ print(f"{'Added' if is_add else 'Deleted'} {permission_count} permission(s)")
+
+
+@cli_utils.action_cli
+@suppress_logs_and_warning
+def roles_add_perms(args):
+ """Adds permissions to role in DB."""
+ __roles_add_or_remove_permissions(args)
+
+
+@cli_utils.action_cli
+@suppress_logs_and_warning
+def roles_del_perms(args):
+ """Deletes permissions from role in DB."""
+ __roles_add_or_remove_permissions(args)
+
+
@suppress_logs_and_warning
def roles_export(args):
"""
- Exports all the rules from the data base to a file.
+ Exports all the roles from the database to a file.
+
Note, this function does not export the permissions associated for each role.
Strictly, it exports the role names into the passed role json file.
"""
@@ -58,8 +156,8 @@ def roles_export(args):
roles = appbuilder.sm.get_all_roles()
exporting_roles = [role.name for role in roles if role.name not in EXISTING_ROLES]
filename = os.path.expanduser(args.file)
- kwargs = {} if not args.pretty else {'sort_keys': True, 'indent': 4}
- with open(filename, 'w', encoding='utf-8') as f:
+ kwargs = {} if not args.pretty else {"sort_keys": True, "indent": 4}
+ with open(filename, "w", encoding="utf-8") as f:
json.dump(exporting_roles, f, **kwargs)
print(f"{len(exporting_roles)} roles successfully exported to {filename}")
@@ -69,6 +167,7 @@ def roles_export(args):
def roles_import(args):
"""
Import all the roles into the db from the given json file.
+
Note, this function does not import the permissions for different roles and import them as well.
Strictly, it imports the role names in the role json file passed.
"""
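
Note on `__roles_add_or_remove_permissions`: the loop iterates over `itertools.product(args.role, args.resource, args.action or [None])`. The `or [None]` fallback matters because a product with an empty factor yields nothing, and the delete path treats a `None` action as "all actions on this resource". A small sketch of just that expansion step follows; `expand_targets` is a hypothetical name for illustration.

```python
import itertools


def expand_targets(roles, resources, actions=None):
    """Return every (role, resource, action) combination to process.

    When no actions are given, a single ``None`` placeholder keeps the
    product non-empty, so each (role, resource) pair is still visited once.
    """
    return list(itertools.product(roles, resources, actions or [None]))
```

Without the placeholder, calling the command with roles and resources but no actions would silently process zero combinations.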
diff --git a/airflow/cli/commands/rotate_fernet_key_command.py b/airflow/cli/commands/rotate_fernet_key_command.py
index 9334344d9610b..f9e18735976c7 100644
--- a/airflow/cli/commands/rotate_fernet_key_command.py
+++ b/airflow/cli/commands/rotate_fernet_key_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Rotate Fernet key command"""
+"""Rotate Fernet key command."""
+from __future__ import annotations
+
from airflow.models import Connection, Variable
from airflow.utils import cli as cli_utils
from airflow.utils.session import create_session
@@ -22,7 +24,7 @@
@cli_utils.action_cli
def rotate_fernet_key(args):
- """Rotates all encrypted connection credentials and variables"""
+ """Rotates all encrypted connection credentials and variables."""
with create_session() as session:
for conn in session.query(Connection).filter(Connection.is_encrypted | Connection.is_extra_encrypted):
conn.rotate_fernet_key()
diff --git a/airflow/cli/commands/scheduler_command.py b/airflow/cli/commands/scheduler_command.py
index bc6e983ee5c8e..44544716df6da 100644
--- a/airflow/cli/commands/scheduler_command.py
+++ b/airflow/cli/commands/scheduler_command.py
@@ -14,44 +14,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Scheduler command."""
+from __future__ import annotations
-"""Scheduler command"""
import signal
+from contextlib import contextmanager
from multiprocessing import Process
-from typing import Optional
import daemon
from daemon.pidfile import TimeoutPIDLockFile
from airflow import settings
+from airflow.configuration import conf
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.utils import cli as cli_utils
from airflow.utils.cli import process_subdir, setup_locations, setup_logging, sigint_handler, sigquit_handler
+from airflow.utils.scheduler_health import serve_health_check
-def _create_scheduler_job(args):
+def _run_scheduler_job(args):
job = SchedulerJob(
subdir=process_subdir(args.subdir),
num_runs=args.num_runs,
do_pickle=args.do_pickle,
)
- return job
-
-
-def _run_scheduler_job(args):
- skip_serve_logs = args.skip_serve_logs
- job = _create_scheduler_job(args)
- sub_proc = _serve_logs(skip_serve_logs)
- try:
+ enable_health_check = conf.getboolean("scheduler", "ENABLE_HEALTH_CHECK")
+ with _serve_logs(args.skip_serve_logs), _serve_health_check(enable_health_check):
job.run()
- finally:
- if sub_proc:
- sub_proc.terminate()
@cli_utils.action_cli
def scheduler(args):
- """Starts Airflow Scheduler"""
+ """Starts Airflow Scheduler."""
print(settings.HEADER)
if args.daemon:
@@ -59,12 +53,16 @@ def scheduler(args):
"scheduler", args.pid, args.stdout, args.stderr, args.log_file
)
handle = setup_logging(log_file)
- with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle:
+ with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
+ stdout_handle.truncate(0)
+ stderr_handle.truncate(0)
+
ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pid, -1),
files_preserve=[handle],
stdout=stdout_handle,
stderr=stderr_handle,
+ umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
_run_scheduler_job(args=args)
@@ -75,14 +73,29 @@ def scheduler(args):
_run_scheduler_job(args=args)
-def _serve_logs(skip_serve_logs: bool = False) -> Optional[Process]:
- """Starts serve_logs sub-process"""
+@contextmanager
+def _serve_logs(skip_serve_logs: bool = False):
+ """Starts serve_logs sub-process."""
from airflow.configuration import conf
from airflow.utils.serve_logs import serve_logs
+ sub_proc = None
if conf.get("core", "executor") in ["LocalExecutor", "SequentialExecutor"]:
if skip_serve_logs is False:
sub_proc = Process(target=serve_logs)
sub_proc.start()
- return sub_proc
- return None
+ yield
+ if sub_proc:
+ sub_proc.terminate()
+
+
+@contextmanager
+def _serve_health_check(enable_health_check: bool = False):
+ """Starts serve_health_check sub-process."""
+ sub_proc = None
+ if enable_health_check:
+ sub_proc = Process(target=serve_health_check)
+ sub_proc.start()
+ yield
+ if sub_proc:
+ sub_proc.terminate()
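
Note on the new context managers: `_serve_logs` and `_serve_health_check` terminate the sub-process after `yield`, but the code after a bare `yield` in a `@contextmanager` generator is skipped when the `with` body raises. The removed `_run_scheduler_job` used `try`/`finally` for this. A possible variant that keeps that guarantee is sketched below. `FakeProcess` and `serve_helper` are hypothetical stand-ins so the example runs without spawning real processes; this is not the code in the diff.

```python
from contextlib import contextmanager


class FakeProcess:
    """Hypothetical stand-in for multiprocessing.Process, for illustration only."""

    def __init__(self):
        self.started = False
        self.terminated = False

    def start(self):
        self.started = True

    def terminate(self):
        self.terminated = True


@contextmanager
def serve_helper(enabled, process_factory=FakeProcess):
    """Start a helper process on entry and always terminate it on exit."""
    sub_proc = None
    if enabled:
        sub_proc = process_factory()
        sub_proc.start()
    try:
        yield sub_proc
    finally:
        # Runs even if the body of the ``with`` block raises.
        if sub_proc:
            sub_proc.terminate()
```

Yielding the process object is a convenience for the example; the functions in the diff yield nothing.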
diff --git a/airflow/cli/commands/standalone_command.py b/airflow/cli/commands/standalone_command.py
index 3860942adb056..2a5670e83f01e 100644
--- a/airflow/cli/commands/standalone_command.py
+++ b/airflow/cli/commands/standalone_command.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import logging
import os
@@ -23,7 +24,6 @@
import threading
import time
from collections import deque
-from typing import Dict, List
from termcolor import colored
@@ -55,7 +55,7 @@ def __init__(self):
self.ready_delay = 3
def run(self):
- """Main run loop"""
+ """Main run loop."""
self.print_output("standalone", "Starting Airflow Standalone")
# Silence built-in logging at INFO
logging.getLogger("").setLevel(logging.WARNING)
@@ -82,7 +82,7 @@ def run(self):
env=env,
)
- self.web_server_port = conf.getint('webserver', 'WEB_SERVER_PORT', fallback=8080)
+ self.web_server_port = conf.getint("webserver", "WEB_SERVER_PORT", fallback=8080)
# Run subcommand threads
for command in self.subcommands.values():
command.start()
@@ -116,7 +116,7 @@ def run(self):
self.print_output("standalone", "Complete")
def update_output(self):
- """Drains the output queue and prints its contents to the screen"""
+ """Drains the output queue and prints its contents to the screen."""
while self.output_queue:
# Extract info
name, line = self.output_queue.popleft()
@@ -126,8 +126,9 @@ def update_output(self):
def print_output(self, name: str, output):
"""
- Prints an output line with name and colouring. You can pass multiple
- lines to output if you wish; it will be split for you.
+ Prints an output line with name and colouring.
+
+ You can pass multiple lines to output if you wish; it will be split for you.
"""
color = {
"webserver": "green",
@@ -141,14 +142,16 @@ def print_output(self, name: str, output):
def print_error(self, name: str, output):
"""
- Prints an error message to the console (this is the same as
- print_output but with the text red)
+ Prints an error message to the console.
+
+ This is the same as print_output but with the text red.
"""
self.print_output(name, colored(output, "red"))
def calculate_env(self):
"""
Works out the environment variables needed to run subprocesses.
+
We override some settings as part of being standalone.
"""
env = dict(os.environ)
@@ -217,7 +220,8 @@ def is_ready(self):
def port_open(self, port):
"""
Checks if the given port is listening on the local machine.
- (used to tell if webserver is alive)
+
+ Used to tell if webserver is alive.
"""
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -231,8 +235,9 @@ def port_open(self, port):
def job_running(self, job):
"""
- Checks if the given job name is running and heartbeating correctly
- (used to tell if scheduler is alive)
+ Checks if the given job name is running and heartbeating correctly.
+
+ Used to tell if scheduler is alive.
"""
recent = job.most_recent_job()
if not recent:
@@ -241,8 +246,9 @@ def job_running(self, job):
def print_ready(self):
"""
- Prints the banner shown when Airflow is ready to go, with login
- details.
+ Prints the banner shown when Airflow is ready to go.
+
+ Includes login details.
"""
self.print_output("standalone", "")
self.print_output("standalone", "Airflow is ready")
@@ -260,12 +266,14 @@ def print_ready(self):
class SubCommand(threading.Thread):
"""
+ Execute a subcommand on another thread.
+
Thread that launches a process and then streams its output back to the main
command. We use threads to avoid using select() and raw filehandles, and the
complex logic that brings doing line buffering.
"""
- def __init__(self, parent, name: str, command: List[str], env: Dict[str, str]):
+ def __init__(self, parent, name: str, command: list[str], env: dict[str, str]):
super().__init__()
self.parent = parent
self.name = name
@@ -273,7 +281,7 @@ def __init__(self, parent, name: str, command: List[str], env: Dict[str, str]):
self.env = env
def run(self):
- """Runs the actual process and captures it output to a queue"""
+ """Runs the actual process and captures its output to a queue."""
self.process = subprocess.Popen(
["airflow"] + self.command,
stdout=subprocess.PIPE,
@@ -284,7 +292,7 @@ def run(self):
self.parent.output_queue.append((self.name, line))
def stop(self):
- """Call to stop this process (and thus this thread)"""
+ """Call to stop this process (and thus this thread)."""
self.process.terminate()
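Reviewer note: the `SubCommand` docstring describes a thread that launches a process and streams its output back to the parent through a queue. A minimal standalone sketch of that idea follows. `OutputStreamer` is a hypothetical name, the queue is passed in directly rather than reached through a parent object, and merging stderr into stdout is an assumption of this sketch.

```python
import subprocess
import sys
import threading
from collections import deque


class OutputStreamer(threading.Thread):
    """Run a command and push (name, line) tuples onto a shared deque (hypothetical class)."""

    def __init__(self, name: str, command: list, output_queue: deque):
        super().__init__()
        self.name = name
        self.command = command
        self.output_queue = output_queue
        self.process = None

    def run(self):
        # Launch the child and read its output line by line until it closes the pipe.
        self.process = subprocess.Popen(
            self.command,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        for line in self.process.stdout:
            # Lines arrive as bytes; deque.append is safe to call from a thread.
            self.output_queue.append((self.name, line))
        self.process.wait()

    def stop(self):
        """Stop the child process, which also ends this thread's read loop."""
        if self.process is not None:
            self.process.terminate()
```

The main loop can then drain the deque with `popleft()` at its own pace, which is how `update_output` consumes `output_queue` in the diff.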
diff --git a/airflow/cli/commands/sync_perm_command.py b/airflow/cli/commands/sync_perm_command.py
index d580631fcc8a8..6a92ef99d5c8c 100644
--- a/airflow/cli/commands/sync_perm_command.py
+++ b/airflow/cli/commands/sync_perm_command.py
@@ -15,19 +15,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Sync permission command"""
+"""Sync permission command."""
+from __future__ import annotations
+
from airflow.utils import cli as cli_utils
from airflow.www.app import cached_app
@cli_utils.action_cli
def sync_perm(args):
- """Updates permissions for existing roles and DAGs"""
+ """Updates permissions for existing roles and DAGs."""
appbuilder = cached_app().appbuilder
- print('Updating actions and resources for all existing roles')
+ print("Updating actions and resources for all existing roles")
# Add missing permissions for all the Base Views _before_ syncing/creating roles
appbuilder.add_permissions(update_perms=True)
appbuilder.sm.sync_roles()
if args.include_dags:
- print('Updating permission on all DAG views')
+ print("Updating permission on all DAG views")
appbuilder.sm.create_dag_specific_permissions()
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index f8caf08487187..2f37579c351d3 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -15,15 +15,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Task sub-commands"""
+"""Task sub-commands."""
+from __future__ import annotations
+
import datetime
import importlib
import json
import logging
import os
import textwrap
-from contextlib import contextmanager, redirect_stderr, redirect_stdout
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
+from typing import Generator, Union
from pendulum.parsing.exceptions import ParserError
from sqlalchemy.orm.exc import NoResultFound
@@ -35,17 +37,18 @@
from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
+from airflow.models.operator import needs_expansion
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
- get_dag_by_deserialization,
get_dag_by_file_location,
get_dag_by_pickle,
get_dags,
@@ -53,6 +56,7 @@
)
from airflow.utils.dates import timezone
from airflow.utils.log.logging_mixin import StreamLogWriter
+from airflow.utils.log.secrets_masker import RedactedIO
from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.state import DagRunState
@@ -74,10 +78,10 @@ def _generate_temporary_run_id() -> str:
def _get_dag_run(
*,
dag: DAG,
- exec_date_or_run_id: str,
create_if_necessary: CreateIfNecessary,
+ exec_date_or_run_id: str | None = None,
session: Session,
-) -> Tuple[DagRun, bool]:
+) -> tuple[DagRun, bool]:
"""Try to retrieve a DAG run from a string representing either a run ID or logical date.
This checks DAG runs like this:
@@ -91,33 +95,35 @@ def _get_dag_run(
the logical date; otherwise use it as a run ID and set the logical date
to the current time.
"""
- dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
- if dag_run:
- return dag_run, False
-
- try:
- execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
- except (ParserError, TypeError):
- execution_date = None
-
- try:
- dag_run = (
- session.query(DagRun)
- .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
- .one()
- )
- except NoResultFound:
- if not create_if_necessary:
- raise DagRunNotFound(
- f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
- ) from None
- else:
- return dag_run, False
+ if not exec_date_or_run_id and not create_if_necessary:
+ raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
+ execution_date: datetime.datetime | None = None
+ if exec_date_or_run_id:
+ dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
+ if dag_run:
+ return dag_run, False
+ with suppress(ParserError, TypeError):
+ execution_date = timezone.parse(exec_date_or_run_id)
+ try:
+ dag_run = (
+ session.query(DagRun)
+ .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
+ .one()
+ )
+ except NoResultFound:
+ if not create_if_necessary:
+ raise DagRunNotFound(
+ f"DagRun for {dag.dag_id} with run_id or execution_date "
+ f"of {exec_date_or_run_id!r} not found"
+ ) from None
+ else:
+ return dag_run, False
if execution_date is not None:
dag_run_execution_date = execution_date
else:
dag_run_execution_date = timezone.utcnow()
+
if create_if_necessary == "memory":
dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date)
return dag_run, True
@@ -135,15 +141,17 @@ def _get_dag_run(
@provide_session
def _get_ti(
task: BaseOperator,
- exec_date_or_run_id: str,
map_index: int,
*,
- pool: Optional[str] = None,
+ exec_date_or_run_id: str | None = None,
+ pool: str | None = None,
create_if_necessary: CreateIfNecessary = False,
session: Session = NEW_SESSION,
-) -> Tuple[TaskInstance, bool]:
- """Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
- if task.is_mapped:
+) -> tuple[TaskInstance, bool]:
+ """Get the task instance through DagRun.run_id, if that fails, get the TI the old way."""
+ if not exec_date_or_run_id and not create_if_necessary:
+ raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
+ if needs_expansion(task):
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
elif map_index >= 0:
@@ -173,7 +181,9 @@ def _get_ti(
def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
"""
- Runs the task in one of 3 modes
+ Runs the task based on a mode.
+
+ Any of the 3 modes are available:
- using LocalTaskJob
- as raw task
@@ -189,8 +199,9 @@ def _run_task_by_selected_method(args, dag: DAG, ti: TaskInstance) -> None:
def _run_task_by_executor(args, dag, ti):
"""
- Sends the task to the executor for execution. This can result in the task being started by another host
- if the executor implementation does
+ Sends the task to the executor for execution.
+
+ This can result in the task being started by another host if the executor implementation does so.
"""
pickle_id = None
if args.ship_dag:
@@ -201,9 +212,9 @@ def _run_task_by_executor(args, dag, ti):
session.add(pickle)
pickle_id = pickle.id
# TODO: This should be written to a log
- print(f'Pickled dag {dag} as pickle_id: {pickle_id}')
+ print(f"Pickled dag {dag} as pickle_id: {pickle_id}")
except Exception as e:
- print('Could not pickle the DAG')
+ print("Could not pickle the DAG")
print(e)
raise e
executor = ExecutorLoader.get_default_executor()
@@ -225,7 +236,7 @@ def _run_task_by_executor(args, dag, ti):
def _run_task_by_local_task_job(args, ti):
- """Run LocalTaskJob, which monitors the raw task execution process"""
+ """Run LocalTaskJob, which monitors the raw task execution process."""
run_job = LocalTaskJob(
task_instance=ti,
mark_success=args.mark_success,
@@ -254,7 +265,7 @@ def _run_task_by_local_task_job(args, ti):
def _run_raw_task(args, ti: TaskInstance) -> None:
- """Runs the main task handling code"""
+ """Runs the main task handling code."""
ti._run_raw_task(
mark_success=args.mark_success,
job_id=args.job_id,
@@ -262,7 +273,7 @@ def _run_raw_task(args, ti: TaskInstance) -> None:
)
-def _extract_external_executor_id(args) -> Optional[str]:
+def _extract_external_executor_id(args) -> str | None:
if hasattr(args, "external_executor_id"):
return getattr(args, "external_executor_id")
return os.environ.get("external_executor_id", None)
@@ -270,7 +281,8 @@ def _extract_external_executor_id(args) -> Optional[str]:
@contextmanager
def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
- """Manage logging context for a task run
+ """
+ Manage logging context for a task run.
- Replace the root logger configuration with the airflow.task configuration
so we can capture logs from any custom loggers used in the task.
@@ -280,9 +292,8 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
"""
modify = not settings.DONOT_MODIFY_HANDLERS
-
if modify:
- root_logger, task_logger = logging.getLogger(), logging.getLogger('airflow.task')
+ root_logger, task_logger = logging.getLogger(), logging.getLogger("airflow.task")
orig_level = root_logger.level
root_logger.setLevel(task_logger.level)
@@ -303,9 +314,14 @@ def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
root_logger.handlers[:] = orig_handlers
+class TaskCommandMarker:
+ """Marker for listener hooks, to properly detect from which component they are called."""
+
+
@cli_utils.action_cli(check_db=False)
def task_run(args, dag=None):
- """Run a single task instance.
+ """
+ Run a single task instance.
Note that there must be at least one DagRun for this to start,
i.e. it must have been scheduled and/or triggered previously.
@@ -324,8 +340,8 @@ def task_run(args, dag=None):
unsupported_options = [o for o in RAW_TASK_UNSUPPORTED_OPTION if getattr(args, o)]
if unsupported_options:
- unsupported_raw_task_flags = ', '.join(f'--{o}' for o in RAW_TASK_UNSUPPORTED_OPTION)
- unsupported_flags = ', '.join(f'--{o}' for o in unsupported_options)
+ unsupported_raw_task_flags = ", ".join(f"--{o}" for o in RAW_TASK_UNSUPPORTED_OPTION)
+ unsupported_flags = ", ".join(f"--{o}" for o in unsupported_options)
raise AirflowException(
"Option --raw does not work with some of the other options on this command. "
"You can't use --raw option and the following options: "
@@ -353,39 +369,42 @@ def task_run(args, dag=None):
# processing hundreds of simultaneous tasks.
settings.reconfigure_orm(disable_connection_pool=True)
+ get_listener_manager().hook.on_starting(component=TaskCommandMarker())
+
if args.pickle:
- print(f'Loading pickle id: {args.pickle}')
+ print(f"Loading pickle id: {args.pickle}")
dag = get_dag_by_pickle(args.pickle)
elif not dag:
- if args.local:
- try:
- dag = get_dag_by_deserialization(args.dag_id)
- except AirflowException:
- print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file')
- dag = get_dag(args.subdir, args.dag_id)
- else:
- dag = get_dag(args.subdir, args.dag_id)
+ dag = get_dag(args.subdir, args.dag_id)
else:
# Use DAG from parameter
pass
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool)
+ ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool)
ti.init_run_context(raw=args.raw)
hostname = get_hostname()
log.info("Running %s on host %s", ti, hostname)
- if args.interactive:
- _run_task_by_selected_method(args, dag, ti)
- else:
- with _capture_task_logs(ti):
+ try:
+ if args.interactive:
_run_task_by_selected_method(args, dag, ti)
+ else:
+ with _capture_task_logs(ti):
+ _run_task_by_selected_method(args, dag, ti)
+ finally:
+ try:
+ get_listener_manager().hook.before_stopping(component=TaskCommandMarker())
+ except Exception:
+ pass
@cli_utils.action_cli(check_db=False)
def task_failed_deps(args):
"""
+ Get task instance dependencies that were not met.
+
Returns the unmet dependencies for a task instance from the perspective of the
scheduler (i.e. why a task instance doesn't get scheduled and then queued by the
scheduler, and then run by an executor).
@@ -397,7 +416,7 @@ def task_failed_deps(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+ ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
@@ -420,14 +439,14 @@ def task_state(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
+ ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
print(ti.current_state())
@cli_utils.action_cli(check_db=False)
@suppress_logs_and_warning
def task_list(args, dag=None):
- """Lists the tasks within a DAG at the command line"""
+ """Lists the tasks within a DAG at the command line."""
dag = dag or get_dag(args.subdir, args.dag_id)
if args.tree:
dag.tree_view()
@@ -436,7 +455,7 @@ def task_list(args, dag=None):
print("\n".join(tasks))
-SUPPORTED_DEBUGGER_MODULES: List[str] = [
+SUPPORTED_DEBUGGER_MODULES: list[str] = [
"pudb",
"web_pdb",
"ipdb",
@@ -446,8 +465,9 @@ def task_list(args, dag=None):
def _guess_debugger():
"""
- Trying to guess the debugger used by the user. When it doesn't find any user-installed debugger,
- returns ``pdb``.
+ Tries to guess the debugger used by the user.
+
+ When it doesn't find any user-installed debugger, returns ``pdb``.
List of supported debuggers:
@@ -468,7 +488,7 @@ def _guess_debugger():
@suppress_logs_and_warning
@provide_session
def task_states_for_dag_run(args, session=None):
- """Get the status of all task instances in a DagRun"""
+ """Get the status of all task instances in a DagRun."""
dag_run = (
session.query(DagRun)
.filter(DagRun.run_id == args.execution_date_or_run_id, DagRun.dag_id == args.dag_id)
@@ -493,7 +513,7 @@ def task_states_for_dag_run(args, session=None):
has_mapped_instances = any(ti.map_index >= 0 for ti in dag_run.task_instances)
- def format_task_instance(ti: TaskInstance) -> Dict[str, str]:
+ def format_task_instance(ti: TaskInstance) -> dict[str, str]:
data = {
"dag_id": ti.dag_id,
"execution_date": dag_run.execution_date.isoformat(),
@@ -511,23 +531,23 @@ def format_task_instance(ti: TaskInstance) -> Dict[str, str]:
@cli_utils.action_cli(check_db=False)
def task_test(args, dag=None):
- """Tests task for a given dag_id"""
+ """Tests task for a given dag_id."""
# We want to log output from operators etc to show up here. Normally
# airflow.task would redirect to a file, but here we want it to propagate
# up to the normal airflow handler.
settings.MASK_SECRETS_IN_LOGS = True
- handlers = logging.getLogger('airflow.task').handlers
+ handlers = logging.getLogger("airflow.task").handlers
already_has_stream_handler = False
for handler in handlers:
already_has_stream_handler = isinstance(handler, logging.StreamHandler)
if already_has_stream_handler:
break
if not already_has_stream_handler:
- logging.getLogger('airflow.task').propagate = True
+ logging.getLogger("airflow.task").propagate = True
- env_vars = {'AIRFLOW_TEST_MODE': 'True'}
+ env_vars = {"AIRFLOW_TEST_MODE": "True"}
if args.env_vars:
env_vars.update(args.env_vars)
os.environ.update(env_vars)
@@ -543,13 +563,16 @@ def task_test(args, dag=None):
if task.params:
task.params.validate()
- ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db")
+ ti, dr_created = _get_ti(
+ task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db"
+ )
try:
- if args.dry_run:
- ti.dry_run()
- else:
- ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
+ with redirect_stdout(RedactedIO()):
+ if args.dry_run:
+ ti.dry_run()
+ else:
+ ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
except Exception:
if args.post_mortem:
debugger = _guess_debugger()
@@ -560,7 +583,7 @@ def task_test(args, dag=None):
if not already_has_stream_handler:
# Make sure to reset back to normal. When run for CLI this doesn't
# matter, but it does for test suite
- logging.getLogger('airflow.task').propagate = False
+ logging.getLogger("airflow.task").propagate = False
if dr_created:
with create_session() as session:
session.delete(ti.dag_run)
@@ -568,19 +591,22 @@ def task_test(args, dag=None):
@cli_utils.action_cli(check_db=False)
@suppress_logs_and_warning
-def task_render(args):
- """Renders and displays templated fields for a given task"""
- dag = get_dag(args.subdir, args.dag_id)
+def task_render(args, dag=None):
+ """Renders and displays templated fields for a given task."""
+ if not dag:
+ dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
- ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory")
+ ti, _ = _get_ti(
+ task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory"
+ )
ti.render_templates()
- for attr in task.__class__.template_fields:
+ for attr in task.template_fields:
print(
textwrap.dedent(
f""" # ----------------------------------------------------------
# property: {attr}
# ----------------------------------------------------------
- {getattr(task, attr)}
+ {getattr(ti.task, attr)}
"""
)
)
@@ -588,7 +614,7 @@ def task_render(args):
@cli_utils.action_cli(check_db=False)
def task_clear(args):
- """Clears all task instances or only those matched by regex for a DAG(s)"""
+ """Clears all task instances or only those matched by regex for a DAG(s)."""
logging.basicConfig(level=settings.LOGGING_LEVEL, format=settings.SIMPLE_LOG_FORMAT)
if args.dag_id and not args.subdir and not args.dag_regex and not args.task_regex:
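Reviewer note: the reworked `_get_dag_run` treats its argument first as a run ID, then as a date, and uses `contextlib.suppress` so that an unparseable value leaves the date as `None`. The sketch below illustrates that lookup order with plain Python objects. `resolve_run` and the `runs` dict are hypothetical stand-ins for the DagRun query, and `datetime.fromisoformat` replaces `timezone.parse`, so the suppressed exception is `ValueError` rather than `ParserError`.

```python
import datetime
from contextlib import suppress


def resolve_run(runs: dict, key):
    """Return (run_id, logical_date): try key as a run ID first, then as an ISO date.

    `runs` maps run_id -> logical date and stands in for the DagRun table.
    """
    if key in runs:
        return key, runs[key]
    parsed = None
    # A key that is not a date (or is None) simply leaves `parsed` as None.
    with suppress(ValueError, TypeError):
        parsed = datetime.datetime.fromisoformat(key)
    for run_id, logical_date in runs.items():
        if parsed is not None and logical_date == parsed:
            return run_id, logical_date
    # Nothing matched; the caller decides whether to create a run or fail.
    return None, parsed
```

Returning the parsed date alongside the miss mirrors how the diff reuses `execution_date` when it goes on to create a run.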
diff --git a/airflow/cli/commands/triggerer_command.py b/airflow/cli/commands/triggerer_command.py
index 8bf419268059b..64755f3830291 100644
--- a/airflow/cli/commands/triggerer_command.py
+++ b/airflow/cli/commands/triggerer_command.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Triggerer command."""
+from __future__ import annotations
-"""Triggerer command"""
import signal
import daemon
@@ -24,12 +25,12 @@
from airflow import settings
from airflow.jobs.triggerer_job import TriggererJob
from airflow.utils import cli as cli_utils
-from airflow.utils.cli import setup_locations, setup_logging, sigquit_handler
+from airflow.utils.cli import setup_locations, setup_logging, sigint_handler, sigquit_handler
@cli_utils.action_cli
def triggerer(args):
- """Starts Airflow Triggerer"""
+ """Starts Airflow Triggerer."""
settings.MASK_SECRETS_IN_LOGS = True
print(settings.HEADER)
job = TriggererJob(capacity=args.capacity)
@@ -39,30 +40,22 @@ def triggerer(args):
"triggerer", args.pid, args.stdout, args.stderr, args.log_file
)
handle = setup_logging(log_file)
- with open(stdout, 'w+') as stdout_handle, open(stderr, 'w+') as stderr_handle:
+ with open(stdout, "a") as stdout_handle, open(stderr, "a") as stderr_handle:
+ stdout_handle.truncate(0)
+ stderr_handle.truncate(0)
+
ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(pid, -1),
files_preserve=[handle],
stdout=stdout_handle,
stderr=stderr_handle,
+ umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
job.run()
else:
- # There is a bug in CPython (fixed in March 2022 but not yet released) that
- # makes async.io handle SIGTERM improperly by using async unsafe
- # functions and hanging the triggerer receive SIGPIPE while handling
- # SIGTERN/SIGINT and deadlocking itself. Until the bug is handled
- # we should rather rely on standard handling of the signals rather than
- # adding our own signal handlers. Seems that even if our signal handler
- # just run exit(0) - it caused a race condition that led to the hanging.
- #
- # More details:
- # * https://bugs.python.org/issue39622
- # * https://github.com/python/cpython/issues/83803
- #
- # signal.signal(signal.SIGINT, sigint_handler)
- # signal.signal(signal.SIGTERM, sigint_handler)
+ signal.signal(signal.SIGINT, sigint_handler)
+ signal.signal(signal.SIGTERM, sigint_handler)
signal.signal(signal.SIGQUIT, sigquit_handler)
job.run()
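Reviewer note: the daemon branch above now opens the stdout and stderr files in append mode and then truncates them to zero length, instead of opening them with `w+`. A small sketch of what that sequence does to an existing file is shown here. `open_fresh_log` and the temporary path are hypothetical, and the sketch makes no claim about why the change was made.

```python
import os
import tempfile


def open_fresh_log(path: str):
    """Open `path` in append mode, then discard whatever it already contained."""
    handle = open(path, "a")
    handle.truncate(0)
    return handle


# Demonstration with a throwaway file that already holds output from an earlier run.
log_path = os.path.join(tempfile.mkdtemp(), "triggerer.out")
with open(log_path, "w") as f:
    f.write("output from a previous run\n")

handle = open_fresh_log(log_path)
handle.write("output from this run\n")
handle.close()

with open(log_path) as f:
    contents = f.read()
```

After the truncate, every write still lands at the end of the file, which is the usual behaviour of append mode.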
diff --git a/airflow/cli/commands/user_command.py b/airflow/cli/commands/user_command.py
index ddbb7cfc82d6c..f1c806e942870 100644
--- a/airflow/cli/commands/user_command.py
+++ b/airflow/cli/commands/user_command.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""User sub-commands"""
+"""User sub-commands."""
+from __future__ import annotations
+
import functools
import getpass
import json
@@ -22,7 +24,7 @@
import random
import re
import string
-from typing import Any, Dict, List
+from typing import Any
from marshmallow import Schema, fields, validate
from marshmallow.exceptions import ValidationError
@@ -34,7 +36,7 @@
class UserSchema(Schema):
- """user collection item schema"""
+ """user collection item schema."""
id = fields.Int()
firstname = fields.Str(required=True)
@@ -46,10 +48,10 @@ class UserSchema(Schema):
@suppress_logs_and_warning
def users_list(args):
- """Lists users at the command line"""
+ """Lists users at the command line."""
appbuilder = cached_app().appbuilder
users = appbuilder.sm.get_all_users()
- fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
+ fields = ["id", "username", "email", "first_name", "last_name", "roles"]
AirflowConsole().print_as(
data=users, output=args.output, mapper=lambda x: {f: x.__getattribute__(f) for f in fields}
@@ -58,39 +60,39 @@ def users_list(args):
@cli_utils.action_cli(check_db=True)
def users_create(args):
- """Creates new user in the DB"""
+ """Creates new user in the DB."""
appbuilder = cached_app().appbuilder
role = appbuilder.sm.find_role(args.role)
if not role:
valid_roles = appbuilder.sm.get_all_roles()
- raise SystemExit(f'{args.role} is not a valid role. Valid roles are: {valid_roles}')
+ raise SystemExit(f"{args.role} is not a valid role. Valid roles are: {valid_roles}")
if args.use_random_password:
- password = ''.join(random.choice(string.printable) for _ in range(16))
+ password = "".join(random.choice(string.printable) for _ in range(16))
elif args.password:
password = args.password
else:
- password = getpass.getpass('Password:')
- password_confirmation = getpass.getpass('Repeat for confirmation:')
+ password = getpass.getpass("Password:")
+ password_confirmation = getpass.getpass("Repeat for confirmation:")
if password != password_confirmation:
- raise SystemExit('Passwords did not match')
+ raise SystemExit("Passwords did not match")
if appbuilder.sm.find_user(args.username):
- print(f'{args.username} already exist in the db')
+ print(f"{args.username} already exist in the db")
return
user = appbuilder.sm.add_user(args.username, args.firstname, args.lastname, args.email, role, password)
if user:
print(f'User "{args.username}" created with role "{args.role}"')
else:
- raise SystemExit('Failed to create user')
+ raise SystemExit("Failed to create user")
def _find_user(args):
if not args.username and not args.email:
- raise SystemExit('Missing args: must supply one of --username or --email')
+ raise SystemExit("Missing args: must supply one of --username or --email")
if args.username and args.email:
- raise SystemExit('Conflicting args: must supply either --username or --email, but not both')
+ raise SystemExit("Conflicting args: must supply either --username or --email, but not both")
appbuilder = cached_app().appbuilder
@@ -102,7 +104,7 @@ def _find_user(args):
@cli_utils.action_cli
def users_delete(args):
- """Deletes user from DB"""
+ """Deletes user from DB."""
user = _find_user(args)
appbuilder = cached_app().appbuilder
@@ -110,12 +112,12 @@ def users_delete(args):
if appbuilder.sm.del_register_user(user):
print(f'User "{user.username}" deleted')
else:
- raise SystemExit('Failed to delete user')
+ raise SystemExit("Failed to delete user")
@cli_utils.action_cli
def users_manage_role(args, remove=False):
- """Deletes or appends user roles"""
+ """Deletes or appends user roles."""
user = _find_user(args)
appbuilder = cached_app().appbuilder
@@ -142,10 +144,10 @@ def users_manage_role(args, remove=False):
def users_export(args):
- """Exports all users to the json file"""
+ """Exports all users to the json file."""
appbuilder = cached_app().appbuilder
users = appbuilder.sm.get_all_users()
- fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
+ fields = ["id", "username", "email", "first_name", "last_name", "roles"]
# In the User model the first and last name fields have underscores,
# but the corresponding parameters in the CLI don't
@@ -155,22 +157,22 @@ def remove_underscores(s):
users = [
{
remove_underscores(field): user.__getattribute__(field)
- if field != 'roles'
+ if field != "roles"
else [r.name for r in user.roles]
for field in fields
}
for user in users
]
- with open(args.export, 'w') as file:
+ with open(args.export, "w") as file:
file.write(json.dumps(users, sort_keys=True, indent=4))
print(f"{len(users)} users successfully exported to {file.name}")
@cli_utils.action_cli
def users_import(args):
- """Imports users from the json file"""
- json_file = getattr(args, 'import')
+ """Imports users from the json file."""
+ json_file = getattr(args, "import")
if not os.path.exists(json_file):
raise SystemExit(f"File '{json_file}' does not exist")
@@ -189,7 +191,7 @@ def users_import(args):
print("Updated the following users:\n\t{}".format("\n\t".join(users_updated)))
-def _import_users(users_list: List[Dict[str, Any]]):
+def _import_users(users_list: list[dict[str, Any]]):
appbuilder = cached_app().appbuilder
users_created = []
users_updated = []
@@ -199,15 +201,15 @@ def _import_users(users_list: List[Dict[str, Any]]):
except ValidationError as e:
msg = []
for row_num, failure in e.normalized_messages().items():
- msg.append(f'[Item {row_num}]')
+ msg.append(f"[Item {row_num}]")
for key, value in failure.items():
- msg.append(f'\t{key}: {value}')
- raise SystemExit("Error: Input file didn't pass validation. See below:\n{}".format('\n'.join(msg)))
+ msg.append(f"\t{key}: {value}")
+ raise SystemExit("Error: Input file didn't pass validation. See below:\n{}".format("\n".join(msg)))
for user in users_list:
roles = []
- for rolename in user['roles']:
+ for rolename in user["roles"]:
role = appbuilder.sm.find_role(rolename)
if not role:
valid_roles = appbuilder.sm.get_all_roles()
@@ -215,31 +217,31 @@ def _import_users(users_list: List[Dict[str, Any]]):
roles.append(role)
- existing_user = appbuilder.sm.find_user(email=user['email'])
+ existing_user = appbuilder.sm.find_user(email=user["email"])
if existing_user:
print(f"Found existing user with email '{user['email']}'")
- if existing_user.username != user['username']:
+ if existing_user.username != user["username"]:
raise SystemExit(
f"Error: Changing the username is not allowed - please delete and recreate the user with"
f" email {user['email']!r}"
)
existing_user.roles = roles
- existing_user.first_name = user['firstname']
- existing_user.last_name = user['lastname']
+ existing_user.first_name = user["firstname"]
+ existing_user.last_name = user["lastname"]
appbuilder.sm.update_user(existing_user)
- users_updated.append(user['email'])
+ users_updated.append(user["email"])
else:
print(f"Creating new user with email '{user['email']}'")
appbuilder.sm.add_user(
- username=user['username'],
- first_name=user['firstname'],
- last_name=user['lastname'],
- email=user['email'],
+ username=user["username"],
+ first_name=user["firstname"],
+ last_name=user["lastname"],
+ email=user["email"],
role=roles,
)
- users_created.append(user['email'])
+ users_created.append(user["email"])
return users_created, users_updated
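Reviewer note: `_import_users` looks each incoming user up by email, updates the existing record when the username matches, refuses a username change, and creates the user otherwise. The sketch below restates that flow over plain dictionaries so it can be run without Airflow. `import_users` and the `existing` mapping are hypothetical stand-ins for the security manager, and role lookup and schema validation are left out.

```python
def import_users(existing: dict, incoming: list):
    """Create or update users keyed by email; return (created, updated) email lists.

    `existing` maps email -> user dict and stands in for the security manager.
    """
    created, updated = [], []
    for user in incoming:
        current = existing.get(user["email"])
        if current is not None:
            # Same rule as the diff: an existing email may not switch username.
            if current["username"] != user["username"]:
                raise SystemExit(
                    f"Error: Changing the username is not allowed for email {user['email']!r}"
                )
            current["roles"] = user["roles"]
            current["firstname"] = user["firstname"]
            current["lastname"] = user["lastname"]
            updated.append(user["email"])
        else:
            existing[user["email"]] = dict(user)
            created.append(user["email"])
    return created, updated
```

Returning the two lists separately lets the caller print the created and updated summaries that `users_import` reports.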
diff --git a/airflow/cli/commands/variable_command.py b/airflow/cli/commands/variable_command.py
index 40eb193148a80..009b4704aaaee 100644
--- a/airflow/cli/commands/variable_command.py
+++ b/airflow/cli/commands/variable_command.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Variable subcommands"""
+"""Variable subcommands."""
+from __future__ import annotations
+
import json
import os
from json import JSONDecodeError
@@ -29,7 +31,7 @@
@suppress_logs_and_warning
def variables_list(args):
- """Displays all of the variables"""
+ """Displays all the variables."""
with create_session() as session:
variables = session.query(Variable)
AirflowConsole().print_as(data=variables, output=args.output, mapper=lambda x: {"key": x.key})
@@ -37,7 +39,7 @@ def variables_list(args):
@suppress_logs_and_warning
def variables_get(args):
- """Displays variable by a given name"""
+ """Displays variable by a given name."""
try:
if args.default is None:
var = Variable.get(args.key, deserialize_json=args.json)
@@ -51,21 +53,21 @@ def variables_get(args):
@cli_utils.action_cli
def variables_set(args):
- """Creates new variable with a given name and value"""
+ """Creates new variable with a given name and value."""
Variable.set(args.key, args.value, serialize_json=args.json)
print(f"Variable {args.key} created")
@cli_utils.action_cli
def variables_delete(args):
- """Deletes variable by a given name"""
+ """Deletes variable by a given name."""
Variable.delete(args.key)
print(f"Variable {args.key} deleted")
@cli_utils.action_cli
def variables_import(args):
- """Imports variables from a given file"""
+ """Imports variables from a given file."""
if os.path.exists(args.file):
_import_helper(args.file)
else:
@@ -73,12 +75,12 @@ def variables_import(args):
def variables_export(args):
- """Exports all of the variables to the file"""
+ """Exports all the variables to the file."""
_variable_export_helper(args.file)
def _import_helper(filepath):
- """Helps import variables from the file"""
+ """Helps import variables from the file."""
with open(filepath) as varfile:
data = varfile.read()
@@ -92,7 +94,7 @@ def _import_helper(filepath):
try:
Variable.set(k, v, serialize_json=not isinstance(v, str))
except Exception as e:
- print(f'Variable import failed: {repr(e)}')
+ print(f"Variable import failed: {repr(e)}")
fail_count += 1
else:
suc_count += 1
@@ -102,7 +104,7 @@ def _import_helper(filepath):
def _variable_export_helper(filepath):
- """Helps export all of the variables to the file"""
+ """Helps export all the variables to the file."""
var_dict = {}
with create_session() as session:
qry = session.query(Variable).all()
@@ -115,6 +117,6 @@ def _variable_export_helper(filepath):
val = var.val
var_dict[var.key] = val
- with open(filepath, 'w') as varfile:
+ with open(filepath, "w") as varfile:
varfile.write(json.dumps(var_dict, sort_keys=True, indent=4))
print(f"{len(var_dict)} variables successfully exported to {filepath}")
diff --git a/airflow/cli/commands/version_command.py b/airflow/cli/commands/version_command.py
index 7e5190185884f..d3b735951c1fd 100644
--- a/airflow/cli/commands/version_command.py
+++ b/airflow/cli/commands/version_command.py
@@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Version command"""
+"""Version command."""
+from __future__ import annotations
+
import airflow
def version(args):
- """Displays Airflow version at the command line"""
+ """Displays Airflow version at the command line."""
print(airflow.__version__)
diff --git a/airflow/cli/commands/webserver_command.py b/airflow/cli/commands/webserver_command.py
index c74513d23e2b9..a14f6a38e7357 100644
--- a/airflow/cli/commands/webserver_command.py
+++ b/airflow/cli/commands/webserver_command.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Webserver command."""
+from __future__ import annotations
-"""Webserver command"""
import hashlib
import logging
import os
@@ -26,7 +27,7 @@
import time
from contextlib import suppress
from time import sleep
-from typing import Dict, List, NoReturn
+from typing import NoReturn
import daemon
import psutil
@@ -40,16 +41,16 @@
from airflow.utils.cli import setup_locations, setup_logging
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.process_utils import check_if_pidfile_process_is_running
-from airflow.www.app import create_app
log = logging.getLogger(__name__)
class GunicornMonitor(LoggingMixin):
"""
- Runs forever, monitoring the child processes of @gunicorn_master_proc and
- restarting workers occasionally or when files in the plug-in directory
- has been modified.
+ Runs forever.
+
+ Monitors the child processes of @gunicorn_master_proc and restarts
+ workers occasionally or when files in the plug-in directory have been modified.
Each iteration of the loop traverses one edge of this state transition
diagram, where each state (node) represents
@@ -103,15 +104,17 @@ def __init__(
self._last_plugin_state = self._generate_plugin_state() if reload_on_plugin_change else None
self._restart_on_next_plugin_check = False
- def _generate_plugin_state(self) -> Dict[str, float]:
+ def _generate_plugin_state(self) -> dict[str, float]:
"""
+ Get plugin states.
+
Generate dict of filenames and last modification time of all files in settings.PLUGINS_FOLDER
directory.
"""
if not settings.PLUGINS_FOLDER:
return {}
- all_filenames: List[str] = []
+ all_filenames: list[str] = []
for (root, _, filenames) in os.walk(settings.PLUGINS_FOLDER):
all_filenames.extend(os.path.join(root, f) for f in filenames)
plugin_state = {f: self._get_file_hash(f) for f in sorted(all_filenames)}
@@ -119,7 +122,7 @@ def _generate_plugin_state(self) -> Dict[str, float]:
@staticmethod
def _get_file_hash(fname: str):
- """Calculate MD5 hash for file"""
+ """Calculate MD5 hash for file."""
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
@@ -127,7 +130,7 @@ def _get_file_hash(fname: str):
return hash_md5.hexdigest()
def _get_num_ready_workers_running(self) -> int:
- """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name"""
+ """Returns number of ready Gunicorn workers by looking for READY_PREFIX in process name."""
workers = psutil.Process(self.gunicorn_master_proc.pid).children()
def ready_prefix_on_cmdline(proc):
@@ -143,12 +146,12 @@ def ready_prefix_on_cmdline(proc):
return len(ready_workers)
def _get_num_workers_running(self) -> int:
- """Returns number of running Gunicorn workers processes"""
+ """Returns number of running Gunicorn worker processes."""
workers = psutil.Process(self.gunicorn_master_proc.pid).children()
return len(workers)
def _wait_until_true(self, fn, timeout: int = 0) -> None:
- """Sleeps until fn is true"""
+ """Sleeps until fn is true."""
start_time = time.monotonic()
while not fn():
if 0 < timeout <= time.monotonic() - start_time:
@@ -188,8 +191,10 @@ def _kill_old_workers(self, count: int) -> None:
def _reload_gunicorn(self) -> None:
"""
- Send signal to reload the gunicorn configuration. When gunicorn receive signals, it reload the
- configuration, start the new worker processes with a new configuration and gracefully
+ Send signal to reload the gunicorn configuration.
+
+ When gunicorn receives the signal, it will reload the configuration,
+ start the new worker processes with a new configuration and gracefully
shutdown older workers.
"""
# HUP: Reload the configuration.
@@ -229,7 +234,7 @@ def _check_workers(self) -> None:
# Whenever some workers are not ready, wait until all workers are ready
if num_ready_workers_running < num_workers_running:
self.log.debug(
- '[%d / %d] Some workers are starting up, waiting...',
+ "[%d / %d] Some workers are starting up, waiting...",
num_ready_workers_running,
num_workers_running,
)
@@ -241,7 +246,7 @@ def _check_workers(self) -> None:
if num_workers_running > self.num_workers_expected:
excess = min(num_workers_running - self.num_workers_expected, self.worker_refresh_batch_size)
self.log.debug(
- '[%d / %d] Killing %s workers', num_ready_workers_running, num_workers_running, excess
+ "[%d / %d] Killing %s workers", num_ready_workers_running, num_workers_running, excess
)
self._kill_old_workers(excess)
return
@@ -262,7 +267,7 @@ def _check_workers(self) -> None:
)
# log at info since we are trying fix an error logged just above
self.log.info(
- '[%d / %d] Spawning %d workers',
+ "[%d / %d] Spawning %d workers",
num_ready_workers_running,
num_workers_running,
new_worker_count,
@@ -279,7 +284,7 @@ def _check_workers(self) -> None:
if self.worker_refresh_interval < last_refresh_diff:
num_new_workers = self.worker_refresh_batch_size
self.log.debug(
- '[%d / %d] Starting doing a refresh. Starting %d workers.',
+ "[%d / %d] Starting doing a refresh. Starting %d workers.",
num_ready_workers_running,
num_workers_running,
num_new_workers,
@@ -295,8 +300,8 @@ def _check_workers(self) -> None:
# If changed, wait until its content is fully saved.
if new_state != self._last_plugin_state:
self.log.debug(
- '[%d / %d] Plugins folder changed. The gunicorn will be restarted the next time the '
- 'plugin directory is checked, if there is no change in it.',
+ "[%d / %d] Plugins folder changed. The gunicorn will be restarted the next time the "
+ "plugin directory is checked, if there is no change in it.",
num_ready_workers_running,
num_workers_running,
)
@@ -304,7 +309,7 @@ def _check_workers(self) -> None:
self._last_plugin_state = new_state
elif self._restart_on_next_plugin_check:
self.log.debug(
- '[%d / %d] Starts reloading the gunicorn configuration.',
+ "[%d / %d] Starts reloading the gunicorn configuration.",
num_ready_workers_running,
num_workers_running,
)
@@ -315,11 +320,11 @@ def _check_workers(self) -> None:
@cli_utils.action_cli
def webserver(args):
- """Starts Airflow Webserver"""
+ """Starts Airflow Webserver."""
print(settings.HEADER)
# Check for old/insecure config, and fail safe (i.e. don't launch) if the config is wildly insecure.
- if conf.get('webserver', 'secret_key') == 'temporary_key':
+ if conf.get("webserver", "secret_key") == "temporary_key":
from rich import print as rich_print
rich_print(
@@ -331,24 +336,26 @@ def webserver(args):
)
sys.exit(1)
- access_logfile = args.access_logfile or conf.get('webserver', 'access_logfile')
- error_logfile = args.error_logfile or conf.get('webserver', 'error_logfile')
- access_logformat = args.access_logformat or conf.get('webserver', 'access_logformat')
- num_workers = args.workers or conf.get('webserver', 'workers')
- worker_timeout = args.worker_timeout or conf.get('webserver', 'web_server_worker_timeout')
- ssl_cert = args.ssl_cert or conf.get('webserver', 'web_server_ssl_cert')
- ssl_key = args.ssl_key or conf.get('webserver', 'web_server_ssl_key')
+ access_logfile = args.access_logfile or conf.get("webserver", "access_logfile")
+ error_logfile = args.error_logfile or conf.get("webserver", "error_logfile")
+ access_logformat = args.access_logformat or conf.get("webserver", "access_logformat")
+ num_workers = args.workers or conf.get("webserver", "workers")
+ worker_timeout = args.worker_timeout or conf.get("webserver", "web_server_worker_timeout")
+ ssl_cert = args.ssl_cert or conf.get("webserver", "web_server_ssl_cert")
+ ssl_key = args.ssl_key or conf.get("webserver", "web_server_ssl_key")
if not ssl_cert and ssl_key:
- raise AirflowException('An SSL certificate must also be provided for use with ' + ssl_key)
+ raise AirflowException("An SSL certificate must also be provided for use with " + ssl_key)
if ssl_cert and not ssl_key:
- raise AirflowException('An SSL key must also be provided for use with ' + ssl_cert)
+ raise AirflowException("An SSL key must also be provided for use with " + ssl_cert)
+
+ from airflow.www.app import create_app
if args.debug:
print(f"Starting the web server on port {args.port} and host {args.hostname}.")
- app = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
+ app = create_app(testing=conf.getboolean("core", "unit_test_mode"))
app.run(
debug=True,
- use_reloader=not app.config['TESTING'],
+ use_reloader=not app.config["TESTING"],
port=args.port,
host=args.hostname,
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None,
@@ -364,54 +371,60 @@ def webserver(args):
print(
textwrap.dedent(
- f'''\
+ f"""\
Running the Gunicorn Server with:
Workers: {num_workers} {args.workerclass}
Host: {args.hostname}:{args.port}
Timeout: {worker_timeout}
Logfiles: {access_logfile} {error_logfile}
Access Logformat: {access_logformat}
- ================================================================='''
+ ================================================================="""
)
)
run_args = [
sys.executable,
- '-m',
- 'gunicorn',
- '--workers',
+ "-m",
+ "gunicorn",
+ "--workers",
str(num_workers),
- '--worker-class',
+ "--worker-class",
str(args.workerclass),
- '--timeout',
+ "--timeout",
str(worker_timeout),
- '--bind',
- args.hostname + ':' + str(args.port),
- '--name',
- 'airflow-webserver',
- '--pid',
+ "--bind",
+ args.hostname + ":" + str(args.port),
+ "--name",
+ "airflow-webserver",
+ "--pid",
pid_file,
- '--config',
- 'python:airflow.www.gunicorn_config',
+ "--config",
+ "python:airflow.www.gunicorn_config",
]
if args.access_logfile:
- run_args += ['--access-logfile', str(args.access_logfile)]
+ run_args += ["--access-logfile", str(args.access_logfile)]
if args.error_logfile:
- run_args += ['--error-logfile', str(args.error_logfile)]
+ run_args += ["--error-logfile", str(args.error_logfile)]
if args.access_logformat and args.access_logformat.strip():
- run_args += ['--access-logformat', str(args.access_logformat)]
+ run_args += ["--access-logformat", str(args.access_logformat)]
if args.daemon:
- run_args += ['--daemon']
+ run_args += ["--daemon"]
if ssl_cert:
- run_args += ['--certfile', ssl_cert, '--keyfile', ssl_key]
+ run_args += ["--certfile", ssl_cert, "--keyfile", ssl_key]
run_args += ["airflow.www.app:cached_app()"]
+ # To prevent different workers creating the web app and
+ # all writing to the database at the same time, we use the --preload option.
+ # With the preload option, the app is loaded before the workers are forked, and each worker will
+ # then have a copy of the app
+ run_args += ["--preload"]
+
gunicorn_master_proc = None
def kill_proc(signum, _):
@@ -432,29 +445,33 @@ def monitor_gunicorn(gunicorn_master_pid: int):
GunicornMonitor(
gunicorn_master_pid=gunicorn_master_pid,
num_workers_expected=num_workers,
- master_timeout=conf.getint('webserver', 'web_server_master_timeout'),
- worker_refresh_interval=conf.getint('webserver', 'worker_refresh_interval', fallback=30),
- worker_refresh_batch_size=conf.getint('webserver', 'worker_refresh_batch_size', fallback=1),
+ master_timeout=conf.getint("webserver", "web_server_master_timeout"),
+ worker_refresh_interval=conf.getint("webserver", "worker_refresh_interval", fallback=30),
+ worker_refresh_batch_size=conf.getint("webserver", "worker_refresh_batch_size", fallback=1),
reload_on_plugin_change=conf.getboolean(
- 'webserver', 'reload_on_plugin_change', fallback=False
+ "webserver", "reload_on_plugin_change", fallback=False
),
).start()
if args.daemon:
# This makes possible errors get reported before daemonization
- os.environ['SKIP_DAGS_PARSING'] = 'True'
+ os.environ["SKIP_DAGS_PARSING"] = "True"
app = create_app(None)
- os.environ.pop('SKIP_DAGS_PARSING')
+ os.environ.pop("SKIP_DAGS_PARSING")
handle = setup_logging(log_file)
base, ext = os.path.splitext(pid_file)
- with open(stdout, 'w+') as stdout, open(stderr, 'w+') as stderr:
+ with open(stdout, "a") as stdout, open(stderr, "a") as stderr:
+ stdout.truncate(0)
+ stderr.truncate(0)
+
ctx = daemon.DaemonContext(
pidfile=TimeoutPIDLockFile(f"{base}-monitor{ext}", -1),
files_preserve=[handle],
stdout=stdout,
stderr=stderr,
+ umask=int(settings.DAEMON_UMASK, 8),
)
with ctx:
subprocess.Popen(run_args, close_fds=True)
diff --git a/airflow/cli/simple_table.py b/airflow/cli/simple_table.py
index efd418d21fcbe..87f7ce9f3b853 100644
--- a/airflow/cli/simple_table.py
+++ b/airflow/cli/simple_table.py
@@ -14,9 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import inspect
import json
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable
from rich.box import ASCII_DOUBLE_HEAD
from rich.console import Console
@@ -30,7 +32,7 @@
class AirflowConsole(Console):
- """Airflow rich console"""
+ """Airflow rich console."""
def __init__(self, show_header: bool = True, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -40,18 +42,18 @@ def __init__(self, show_header: bool = True, *args, **kwargs):
# If show header in tables
self.show_header = show_header
- def print_as_json(self, data: Dict):
- """Renders dict as json text representation"""
+ def print_as_json(self, data: dict):
+ """Renders dict as json text representation."""
json_content = json.dumps(data)
self.print(Syntax(json_content, "json", theme="ansi_dark"), soft_wrap=True)
- def print_as_yaml(self, data: Dict):
- """Renders dict as yaml text representation"""
+ def print_as_yaml(self, data: dict):
+ """Renders dict as yaml text representation."""
yaml_content = yaml.dump(data)
self.print(Syntax(yaml_content, "yaml", theme="ansi_dark"), soft_wrap=True)
- def print_as_table(self, data: List[Dict]):
- """Renders list of dictionaries as table"""
+ def print_as_table(self, data: list[dict]):
+ """Renders list of dictionaries as table."""
if not data:
self.print("No data found")
return
@@ -64,8 +66,8 @@ def print_as_table(self, data: List[Dict]):
table.add_row(*(str(d) for d in row.values()))
self.print(table)
- def print_as_plain_table(self, data: List[Dict]):
- """Renders list of dictionaries as a simple table than can be easily piped"""
+ def print_as_plain_table(self, data: list[dict]):
+ """Renders list of dictionaries as a simple table that can be easily piped."""
if not data:
self.print("No data found")
return
@@ -73,7 +75,7 @@ def print_as_plain_table(self, data: List[Dict]):
output = tabulate(rows, tablefmt="plain", headers=list(data[0].keys()))
print(output)
- def _normalize_data(self, value: Any, output: str) -> Optional[Union[list, str, dict]]:
+ def _normalize_data(self, value: Any, output: str) -> list | str | dict | None:
if isinstance(value, (tuple, list)):
if output == "table":
return ",".join(str(self._normalize_data(x, output)) for x in value)
@@ -86,9 +88,9 @@ def _normalize_data(self, value: Any, output: str) -> Optional[Union[list, str,
return None
return str(value)
- def print_as(self, data: List[Union[Dict, Any]], output: str, mapper: Optional[Callable] = None):
- """Prints provided using format specified by output argument"""
- output_to_renderer: Dict[str, Callable[[Any], None]] = {
+ def print_as(self, data: list[dict | Any], output: str, mapper: Callable | None = None):
+ """Prints provided data using the format specified by the output argument."""
+ output_to_renderer: dict[str, Callable[[Any], None]] = {
"json": self.print_as_json,
"yaml": self.print_as_yaml,
"table": self.print_as_table,
@@ -104,7 +106,7 @@ def print_as(self, data: List[Union[Dict, Any]], output: str, mapper: Optional[C
raise ValueError("To tabulate non-dictionary data you need to provide `mapper` function")
if mapper:
- dict_data: List[Dict] = [mapper(d) for d in data]
+ dict_data: list[dict] = [mapper(d) for d in data]
else:
dict_data = data
dict_data = [{k: self._normalize_data(v, output) for k, v in d.items()} for d in dict_data]
@@ -125,6 +127,6 @@ def __init__(self, *args, **kwargs):
self.caption = kwargs.get("caption", " ")
def add_column(self, *args, **kwargs) -> None:
- """Add a column to the table. We use different default"""
+ """Add a column to the table. We use a different default."""
kwargs["overflow"] = kwargs.get("overflow") # to avoid truncating
super().add_column(*args, **kwargs)
diff --git a/airflow/compat/functools.py b/airflow/compat/functools.py
index e3dea0a660bbc..dc0c520b796c0 100644
--- a/airflow/compat/functools.py
+++ b/airflow/compat/functools.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import sys
if sys.version_info >= (3, 8):
diff --git a/airflow/compat/sqlalchemy.py b/airflow/compat/sqlalchemy.py
deleted file mode 100644
index 427db90a73d67..0000000000000
--- a/airflow/compat/sqlalchemy.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from sqlalchemy import Table
-from sqlalchemy.engine import Connection
-
-try:
- from sqlalchemy import inspect
-except AttributeError:
- from sqlalchemy.engine.reflection import Inspector
-
- inspect = Inspector.from_engine
-
-__all__ = ["has_table", "inspect"]
-
-
-def has_table(conn: Connection, table: Table):
- try:
- return inspect(conn).has_table(table)
- except AttributeError:
- return table.exists(conn)
diff --git a/airflow/config_templates/airflow_local_settings.py b/airflow/config_templates/airflow_local_settings.py
index b2752c2be7c25..01edea7520f33 100644
--- a/airflow/config_templates/airflow_local_settings.py
+++ b/airflow/config_templates/airflow_local_settings.py
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Airflow logging settings"""
+"""Airflow logging settings."""
+from __future__ import annotations
import os
from pathlib import Path
-from typing import Any, Dict, Optional, Union
-from urllib.parse import urlparse
+from typing import Any
+from urllib.parse import urlsplit
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -29,123 +30,147 @@
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
-LOG_LEVEL: str = conf.get_mandatory_value('logging', 'LOGGING_LEVEL').upper()
+LOG_LEVEL: str = conf.get_mandatory_value("logging", "LOGGING_LEVEL").upper()
# Flask appbuilder's info level log is very verbose,
# so it's set to 'WARN' by default.
-FAB_LOG_LEVEL: str = conf.get_mandatory_value('logging', 'FAB_LOGGING_LEVEL').upper()
+FAB_LOG_LEVEL: str = conf.get_mandatory_value("logging", "FAB_LOGGING_LEVEL").upper()
-LOG_FORMAT: str = conf.get_mandatory_value('logging', 'LOG_FORMAT')
+LOG_FORMAT: str = conf.get_mandatory_value("logging", "LOG_FORMAT")
+DAG_PROCESSOR_LOG_FORMAT: str = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_FORMAT")
-COLORED_LOG_FORMAT: str = conf.get_mandatory_value('logging', 'COLORED_LOG_FORMAT')
+LOG_FORMATTER_CLASS: str = conf.get_mandatory_value(
+ "logging", "LOG_FORMATTER_CLASS", fallback="airflow.utils.log.timezone_aware.TimezoneAware"
+)
+
+COLORED_LOG_FORMAT: str = conf.get_mandatory_value("logging", "COLORED_LOG_FORMAT")
+
+COLORED_LOG: bool = conf.getboolean("logging", "COLORED_CONSOLE_LOG")
-COLORED_LOG: bool = conf.getboolean('logging', 'COLORED_CONSOLE_LOG')
+COLORED_FORMATTER_CLASS: str = conf.get_mandatory_value("logging", "COLORED_FORMATTER_CLASS")
-COLORED_FORMATTER_CLASS: str = conf.get_mandatory_value('logging', 'COLORED_FORMATTER_CLASS')
+DAG_PROCESSOR_LOG_TARGET: str = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET")
-BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'BASE_LOG_FOLDER')
+BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "BASE_LOG_FOLDER")
-PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value('scheduler', 'CHILD_PROCESS_LOG_DIRECTORY')
+PROCESSOR_LOG_FOLDER: str = conf.get_mandatory_value("scheduler", "CHILD_PROCESS_LOG_DIRECTORY")
DAG_PROCESSOR_MANAGER_LOG_LOCATION: str = conf.get_mandatory_value(
- 'logging', 'DAG_PROCESSOR_MANAGER_LOG_LOCATION'
+ "logging", "DAG_PROCESSOR_MANAGER_LOG_LOCATION"
)
-FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_FILENAME_TEMPLATE')
+# FILENAME_TEMPLATE is only used in remote logging handlers since Airflow 2.3.3.
+# All of these handlers inherit from FileTaskHandler, and providing any value other than None
+# would raise a deprecation warning.
+FILENAME_TEMPLATE: str | None = None
-PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value('logging', 'LOG_PROCESSOR_FILENAME_TEMPLATE')
+PROCESSOR_FILENAME_TEMPLATE: str = conf.get_mandatory_value("logging", "LOG_PROCESSOR_FILENAME_TEMPLATE")
-DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
- 'version': 1,
- 'disable_existing_loggers': False,
- 'formatters': {
- 'airflow': {'format': LOG_FORMAT},
- 'airflow_coloured': {
- 'format': COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT,
- 'class': COLORED_FORMATTER_CLASS if COLORED_LOG else 'logging.Formatter',
+DEFAULT_LOGGING_CONFIG: dict[str, Any] = {
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {
+ "airflow": {
+ "format": LOG_FORMAT,
+ "class": LOG_FORMATTER_CLASS,
+ },
+ "airflow_coloured": {
+ "format": COLORED_LOG_FORMAT if COLORED_LOG else LOG_FORMAT,
+ "class": COLORED_FORMATTER_CLASS if COLORED_LOG else LOG_FORMATTER_CLASS,
+ },
+ "source_processor": {
+ "format": DAG_PROCESSOR_LOG_FORMAT,
+ "class": LOG_FORMATTER_CLASS,
},
},
- 'filters': {
- 'mask_secrets': {
- '()': 'airflow.utils.log.secrets_masker.SecretsMasker',
+ "filters": {
+ "mask_secrets": {
+ "()": "airflow.utils.log.secrets_masker.SecretsMasker",
},
},
- 'handlers': {
- 'console': {
- 'class': 'airflow.utils.log.logging_mixin.RedirectStdHandler',
- 'formatter': 'airflow_coloured',
- 'stream': 'sys.stdout',
- 'filters': ['mask_secrets'],
+ "handlers": {
+ "console": {
+ "class": "airflow.utils.log.logging_mixin.RedirectStdHandler",
+ "formatter": "airflow_coloured",
+ "stream": "sys.stdout",
+ "filters": ["mask_secrets"],
+ },
+ "task": {
+ "class": "airflow.utils.log.file_task_handler.FileTaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER),
+ "filters": ["mask_secrets"],
},
- 'task': {
- 'class': 'airflow.utils.log.file_task_handler.FileTaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
- 'filename_template': FILENAME_TEMPLATE,
- 'filters': ['mask_secrets'],
+ "processor": {
+ "class": "airflow.utils.log.file_processor_handler.FileProcessorHandler",
+ "formatter": "airflow",
+ "base_log_folder": os.path.expanduser(PROCESSOR_LOG_FOLDER),
+ "filename_template": PROCESSOR_FILENAME_TEMPLATE,
+ "filters": ["mask_secrets"],
},
- 'processor': {
- 'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler',
- 'formatter': 'airflow',
- 'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
- 'filename_template': PROCESSOR_FILENAME_TEMPLATE,
- 'filters': ['mask_secrets'],
+ "processor_to_stdout": {
+ "class": "airflow.utils.log.logging_mixin.RedirectStdHandler",
+ "formatter": "source_processor",
+ "stream": "sys.stdout",
+ "filters": ["mask_secrets"],
},
},
- 'loggers': {
- 'airflow.processor': {
- 'handlers': ['processor'],
- 'level': LOG_LEVEL,
- 'propagate': False,
+ "loggers": {
+ "airflow.processor": {
+ "handlers": ["processor_to_stdout" if DAG_PROCESSOR_LOG_TARGET == "stdout" else "processor"],
+ "level": LOG_LEVEL,
+ # Set to true here (and reset via set_context) so that if no file is configured we still get logs!
+ "propagate": True,
},
- 'airflow.task': {
- 'handlers': ['task'],
- 'level': LOG_LEVEL,
- 'propagate': False,
- 'filters': ['mask_secrets'],
+ "airflow.task": {
+ "handlers": ["task"],
+ "level": LOG_LEVEL,
+ # Set to true here (and reset via set_context) so that if no file is configured we still get logs!
+ "propagate": True,
+ "filters": ["mask_secrets"],
},
- 'flask_appbuilder': {
- 'handlers': ['console'],
- 'level': FAB_LOG_LEVEL,
- 'propagate': True,
+ "flask_appbuilder": {
+ "handlers": ["console"],
+ "level": FAB_LOG_LEVEL,
+ "propagate": True,
},
},
- 'root': {
- 'handlers': ['console'],
- 'level': LOG_LEVEL,
- 'filters': ['mask_secrets'],
+ "root": {
+ "handlers": ["console"],
+ "level": LOG_LEVEL,
+ "filters": ["mask_secrets"],
},
}
-EXTRA_LOGGER_NAMES: Optional[str] = conf.get('logging', 'EXTRA_LOGGER_NAMES', fallback=None)
+EXTRA_LOGGER_NAMES: str | None = conf.get("logging", "EXTRA_LOGGER_NAMES", fallback=None)
if EXTRA_LOGGER_NAMES:
new_loggers = {
logger_name.strip(): {
- 'handlers': ['console'],
- 'level': LOG_LEVEL,
- 'propagate': True,
+ "handlers": ["console"],
+ "level": LOG_LEVEL,
+ "propagate": True,
}
for logger_name in EXTRA_LOGGER_NAMES.split(",")
}
- DEFAULT_LOGGING_CONFIG['loggers'].update(new_loggers)
-
-DEFAULT_DAG_PARSING_LOGGING_CONFIG: Dict[str, Dict[str, Dict[str, Any]]] = {
- 'handlers': {
- 'processor_manager': {
- 'class': 'logging.handlers.RotatingFileHandler',
- 'formatter': 'airflow',
- 'filename': DAG_PROCESSOR_MANAGER_LOG_LOCATION,
- 'mode': 'a',
- 'maxBytes': 104857600, # 100MB
- 'backupCount': 5,
+ DEFAULT_LOGGING_CONFIG["loggers"].update(new_loggers)
+
+DEFAULT_DAG_PARSING_LOGGING_CONFIG: dict[str, dict[str, dict[str, Any]]] = {
+ "handlers": {
+ "processor_manager": {
+ "class": "airflow.utils.log.non_caching_file_handler.NonCachingRotatingFileHandler",
+ "formatter": "airflow",
+ "filename": DAG_PROCESSOR_MANAGER_LOG_LOCATION,
+ "mode": "a",
+ "maxBytes": 104857600, # 100MB
+ "backupCount": 5,
}
},
- 'loggers': {
- 'airflow.processor_manager': {
- 'handlers': ['processor_manager'],
- 'level': LOG_LEVEL,
- 'propagate': False,
+ "loggers": {
+ "airflow.processor_manager": {
+ "handlers": ["processor_manager"],
+ "level": LOG_LEVEL,
+ "propagate": False,
}
},
}
@@ -153,27 +178,27 @@
# Only update the handlers and loggers when CONFIG_PROCESSOR_MANAGER_LOGGER is set.
# This is to avoid exceptions when initializing RotatingFileHandler multiple times
# in multiple processes.
-if os.environ.get('CONFIG_PROCESSOR_MANAGER_LOGGER') == 'True':
- DEFAULT_LOGGING_CONFIG['handlers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'])
- DEFAULT_LOGGING_CONFIG['loggers'].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG['loggers'])
+if os.environ.get("CONFIG_PROCESSOR_MANAGER_LOGGER") == "True":
+ DEFAULT_LOGGING_CONFIG["handlers"].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"])
+ DEFAULT_LOGGING_CONFIG["loggers"].update(DEFAULT_DAG_PARSING_LOGGING_CONFIG["loggers"])
# Manually create log directory for processor_manager handler as RotatingFileHandler
# will only create file but not the directory.
- processor_manager_handler_config: Dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG['handlers'][
- 'processor_manager'
+ processor_manager_handler_config: dict[str, Any] = DEFAULT_DAG_PARSING_LOGGING_CONFIG["handlers"][
+ "processor_manager"
]
- directory: str = os.path.dirname(processor_manager_handler_config['filename'])
+ directory: str = os.path.dirname(processor_manager_handler_config["filename"])
Path(directory).mkdir(parents=True, exist_ok=True, mode=0o755)
##################
# Remote logging #
##################
-REMOTE_LOGGING: bool = conf.getboolean('logging', 'remote_logging')
+REMOTE_LOGGING: bool = conf.getboolean("logging", "remote_logging")
if REMOTE_LOGGING:
- ELASTICSEARCH_HOST: Optional[str] = conf.get('elasticsearch', 'HOST')
+ ELASTICSEARCH_HOST: str | None = conf.get("elasticsearch", "HOST")
# Storage bucket URL for remote logging
# S3 buckets should start with "s3://"
@@ -181,115 +206,115 @@
# GCS buckets should start with "gs://"
# WASB buckets should start with "wasb"
# just to help Airflow select correct handler
- REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value('logging', 'REMOTE_BASE_LOG_FOLDER')
-
- if REMOTE_BASE_LOG_FOLDER.startswith('s3://'):
- S3_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
- 'task': {
- 'class': 'airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
- 's3_log_folder': REMOTE_BASE_LOG_FOLDER,
- 'filename_template': FILENAME_TEMPLATE,
+ REMOTE_BASE_LOG_FOLDER: str = conf.get_mandatory_value("logging", "REMOTE_BASE_LOG_FOLDER")
+
+ if REMOTE_BASE_LOG_FOLDER.startswith("s3://"):
+ S3_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = {
+ "task": {
+ "class": "airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)),
+ "s3_log_folder": REMOTE_BASE_LOG_FOLDER,
+ "filename_template": FILENAME_TEMPLATE,
},
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(S3_REMOTE_HANDLERS)
- elif REMOTE_BASE_LOG_FOLDER.startswith('cloudwatch://'):
- url_parts = urlparse(REMOTE_BASE_LOG_FOLDER)
- CLOUDWATCH_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
- 'task': {
- 'class': 'airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
- 'log_group_arn': url_parts.netloc + url_parts.path,
- 'filename_template': FILENAME_TEMPLATE,
+ DEFAULT_LOGGING_CONFIG["handlers"].update(S3_REMOTE_HANDLERS)
+ elif REMOTE_BASE_LOG_FOLDER.startswith("cloudwatch://"):
+ url_parts = urlsplit(REMOTE_BASE_LOG_FOLDER)
+ CLOUDWATCH_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = {
+ "task": {
+ "class": "airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)),
+ "log_group_arn": url_parts.netloc + url_parts.path,
+ "filename_template": FILENAME_TEMPLATE,
},
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(CLOUDWATCH_REMOTE_HANDLERS)
- elif REMOTE_BASE_LOG_FOLDER.startswith('gs://'):
- key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None)
- GCS_REMOTE_HANDLERS: Dict[str, Dict[str, str]] = {
- 'task': {
- 'class': 'airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
- 'gcs_log_folder': REMOTE_BASE_LOG_FOLDER,
- 'filename_template': FILENAME_TEMPLATE,
- 'gcp_key_path': key_path,
+ DEFAULT_LOGGING_CONFIG["handlers"].update(CLOUDWATCH_REMOTE_HANDLERS)
+ elif REMOTE_BASE_LOG_FOLDER.startswith("gs://"):
+ key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None)
+ GCS_REMOTE_HANDLERS: dict[str, dict[str, str | None]] = {
+ "task": {
+ "class": "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)),
+ "gcs_log_folder": REMOTE_BASE_LOG_FOLDER,
+ "filename_template": FILENAME_TEMPLATE,
+ "gcp_key_path": key_path,
},
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(GCS_REMOTE_HANDLERS)
- elif REMOTE_BASE_LOG_FOLDER.startswith('wasb'):
- WASB_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = {
- 'task': {
- 'class': 'airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
- 'wasb_log_folder': REMOTE_BASE_LOG_FOLDER,
- 'wasb_container': 'airflow-logs',
- 'filename_template': FILENAME_TEMPLATE,
- 'delete_local_copy': False,
+ DEFAULT_LOGGING_CONFIG["handlers"].update(GCS_REMOTE_HANDLERS)
+ elif REMOTE_BASE_LOG_FOLDER.startswith("wasb"):
+ WASB_REMOTE_HANDLERS: dict[str, dict[str, str | bool | None]] = {
+ "task": {
+ "class": "airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)),
+ "wasb_log_folder": REMOTE_BASE_LOG_FOLDER,
+ "wasb_container": "airflow-logs",
+ "filename_template": FILENAME_TEMPLATE,
+ "delete_local_copy": False,
},
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(WASB_REMOTE_HANDLERS)
- elif REMOTE_BASE_LOG_FOLDER.startswith('stackdriver://'):
- key_path = conf.get_mandatory_value('logging', 'GOOGLE_KEY_PATH', fallback=None)
+ DEFAULT_LOGGING_CONFIG["handlers"].update(WASB_REMOTE_HANDLERS)
+ elif REMOTE_BASE_LOG_FOLDER.startswith("stackdriver://"):
+ key_path = conf.get_mandatory_value("logging", "GOOGLE_KEY_PATH", fallback=None)
# stackdriver:///airflow-tasks => airflow-tasks
- log_name = urlparse(REMOTE_BASE_LOG_FOLDER).path[1:]
+ log_name = urlsplit(REMOTE_BASE_LOG_FOLDER).path[1:]
STACKDRIVER_REMOTE_HANDLERS = {
- 'task': {
- 'class': 'airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler',
- 'formatter': 'airflow',
- 'name': log_name,
- 'gcp_key_path': key_path,
+ "task": {
+ "class": "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler",
+ "formatter": "airflow",
+ "name": log_name,
+ "gcp_key_path": key_path,
}
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(STACKDRIVER_REMOTE_HANDLERS)
- elif REMOTE_BASE_LOG_FOLDER.startswith('oss://'):
+ DEFAULT_LOGGING_CONFIG["handlers"].update(STACKDRIVER_REMOTE_HANDLERS)
+ elif REMOTE_BASE_LOG_FOLDER.startswith("oss://"):
OSS_REMOTE_HANDLERS = {
- 'task': {
- 'class': 'airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
- 'oss_log_folder': REMOTE_BASE_LOG_FOLDER,
- 'filename_template': FILENAME_TEMPLATE,
+ "task": {
+ "class": "airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": os.path.expanduser(BASE_LOG_FOLDER),
+ "oss_log_folder": REMOTE_BASE_LOG_FOLDER,
+ "filename_template": FILENAME_TEMPLATE,
},
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(OSS_REMOTE_HANDLERS)
+ DEFAULT_LOGGING_CONFIG["handlers"].update(OSS_REMOTE_HANDLERS)
elif ELASTICSEARCH_HOST:
- ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value('elasticsearch', 'LOG_ID_TEMPLATE')
- ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value('elasticsearch', 'END_OF_LOG_MARK')
- ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value('elasticsearch', 'frontend')
- ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean('elasticsearch', 'WRITE_STDOUT')
- ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean('elasticsearch', 'JSON_FORMAT')
- ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value('elasticsearch', 'JSON_FIELDS')
- ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value('elasticsearch', 'HOST_FIELD')
- ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value('elasticsearch', 'OFFSET_FIELD')
-
- ELASTIC_REMOTE_HANDLERS: Dict[str, Dict[str, Union[str, bool]]] = {
- 'task': {
- 'class': 'airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler',
- 'formatter': 'airflow',
- 'base_log_folder': str(os.path.expanduser(BASE_LOG_FOLDER)),
- 'log_id_template': ELASTICSEARCH_LOG_ID_TEMPLATE,
- 'filename_template': FILENAME_TEMPLATE,
- 'end_of_log_mark': ELASTICSEARCH_END_OF_LOG_MARK,
- 'host': ELASTICSEARCH_HOST,
- 'frontend': ELASTICSEARCH_FRONTEND,
- 'write_stdout': ELASTICSEARCH_WRITE_STDOUT,
- 'json_format': ELASTICSEARCH_JSON_FORMAT,
- 'json_fields': ELASTICSEARCH_JSON_FIELDS,
- 'host_field': ELASTICSEARCH_HOST_FIELD,
- 'offset_field': ELASTICSEARCH_OFFSET_FIELD,
+ ELASTICSEARCH_LOG_ID_TEMPLATE: str = conf.get_mandatory_value("elasticsearch", "LOG_ID_TEMPLATE")
+ ELASTICSEARCH_END_OF_LOG_MARK: str = conf.get_mandatory_value("elasticsearch", "END_OF_LOG_MARK")
+ ELASTICSEARCH_FRONTEND: str = conf.get_mandatory_value("elasticsearch", "frontend")
+ ELASTICSEARCH_WRITE_STDOUT: bool = conf.getboolean("elasticsearch", "WRITE_STDOUT")
+ ELASTICSEARCH_JSON_FORMAT: bool = conf.getboolean("elasticsearch", "JSON_FORMAT")
+ ELASTICSEARCH_JSON_FIELDS: str = conf.get_mandatory_value("elasticsearch", "JSON_FIELDS")
+ ELASTICSEARCH_HOST_FIELD: str = conf.get_mandatory_value("elasticsearch", "HOST_FIELD")
+ ELASTICSEARCH_OFFSET_FIELD: str = conf.get_mandatory_value("elasticsearch", "OFFSET_FIELD")
+
+ ELASTIC_REMOTE_HANDLERS: dict[str, dict[str, str | bool | None]] = {
+ "task": {
+ "class": "airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler",
+ "formatter": "airflow",
+ "base_log_folder": str(os.path.expanduser(BASE_LOG_FOLDER)),
+ "log_id_template": ELASTICSEARCH_LOG_ID_TEMPLATE,
+ "filename_template": FILENAME_TEMPLATE,
+ "end_of_log_mark": ELASTICSEARCH_END_OF_LOG_MARK,
+ "host": ELASTICSEARCH_HOST,
+ "frontend": ELASTICSEARCH_FRONTEND,
+ "write_stdout": ELASTICSEARCH_WRITE_STDOUT,
+ "json_format": ELASTICSEARCH_JSON_FORMAT,
+ "json_fields": ELASTICSEARCH_JSON_FIELDS,
+ "host_field": ELASTICSEARCH_HOST_FIELD,
+ "offset_field": ELASTICSEARCH_OFFSET_FIELD,
},
}
- DEFAULT_LOGGING_CONFIG['handlers'].update(ELASTIC_REMOTE_HANDLERS)
+ DEFAULT_LOGGING_CONFIG["handlers"].update(ELASTIC_REMOTE_HANDLERS)
else:
raise AirflowException(
"Incorrect remote log configuration. Please check the configuration of option 'host' in "
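
Note on the hunk above: the remote handler is chosen purely from the prefix of REMOTE_BASE_LOG_FOLDER, with Elasticsearch as the fallback when a host is configured. The sketch below restates that dispatch as a lookup table. It is an illustration, not the module's code: the names PREFIX_TO_HANDLER, pick_task_handler, stackdriver_log_name and cloudwatch_log_group_arn are hypothetical, ValueError stands in for AirflowException, and the example ARN is made up.

```python
from __future__ import annotations

from urllib.parse import urlsplit

# Prefix -> handler class path, in the same order as the elif chain in the diff.
# The class paths are copied from the diff; the table itself is illustrative.
PREFIX_TO_HANDLER = [
    ("s3://", "airflow.providers.amazon.aws.log.s3_task_handler.S3TaskHandler"),
    ("cloudwatch://", "airflow.providers.amazon.aws.log.cloudwatch_task_handler.CloudwatchTaskHandler"),
    ("gs://", "airflow.providers.google.cloud.log.gcs_task_handler.GCSTaskHandler"),
    ("wasb", "airflow.providers.microsoft.azure.log.wasb_task_handler.WasbTaskHandler"),
    ("stackdriver://", "airflow.providers.google.cloud.log.stackdriver_task_handler.StackdriverTaskHandler"),
    ("oss://", "airflow.providers.alibaba.cloud.log.oss_task_handler.OSSTaskHandler"),
]
ELASTICSEARCH_HANDLER = "airflow.providers.elasticsearch.log.es_task_handler.ElasticsearchTaskHandler"


def pick_task_handler(remote_base_log_folder: str, elasticsearch_host: str | None = None) -> str:
    """Return the handler class path that the prefix checks would select."""
    for prefix, handler in PREFIX_TO_HANDLER:
        if remote_base_log_folder.startswith(prefix):
            return handler
    if elasticsearch_host:
        return ELASTICSEARCH_HANDLER
    # Stand-in for the AirflowException raised in the original module.
    raise ValueError("Incorrect remote log configuration.")


def stackdriver_log_name(url: str) -> str:
    """stackdriver:///airflow-tasks => airflow-tasks (drop the leading slash of the path)."""
    return urlsplit(url).path[1:]


def cloudwatch_log_group_arn(url: str) -> str:
    """Everything after cloudwatch:// is treated as the log group ARN."""
    parts = urlsplit(url)
    return parts.netloc + parts.path
```

Because the table is ordered, a folder such as "wasb-logs" still matches the bare "wasb" prefix, which mirrors the startswith("wasb") check in the diff.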
diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index f884959a88840..2c0721aa89fcd 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -33,15 +33,15 @@
Hostname by providing a path to a callable, which will resolve the hostname.
The format is "package.function".
- For example, default value "socket.getfqdn" means that result from getfqdn() of "socket"
- package will be used as hostname.
+ For example, the default value "airflow.utils.net.getfqdn" means that the result of a patched
+ version of socket.getfqdn() is used as the hostname - see https://github.com/python/cpython/issues/49254.
No argument should be required in the function specified.
If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address``
version_added: ~
type: string
example: ~
- default: "socket.getfqdn"
+ default: "airflow.utils.net.getfqdn"
- name: default_timezone
description: |
Default timezone in case supplied date times are naive
@@ -99,6 +99,17 @@
type: string
example: ~
default: "16"
+ - name: mp_start_method
+ description: |
+ The name of the method used in order to start Python processes via the multiprocessing module.
+ This corresponds directly with the options available in the Python docs:
+ https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method.
+ Must be one of the values returned by:
+ https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods.
+ version_added: "2.0.0"
+ type: string
+ default: ~
+ example: "fork"
- name: load_examples
description: |
Whether to load the DAG examples that ship with Airflow. It's good to
@@ -210,6 +221,15 @@
example: ~
default: "False"
see_also: "https://docs.python.org/3/library/pickle.html#comparison-with-json"
+ - name: allowed_deserialization_classes
+ description: |
+ What classes can be imported during deserialization. This is a multi line value.
+ The individual items will be parsed as regexp. Python built-in classes (like dict)
+ are always allowed
+ version_added: 2.5.0
+ type: string
+ default: 'airflow\..*'
+ example: ~
- name: killed_task_cleanup_time
description: |
When a task is killed forcefully, this is the amount of time in seconds that
@@ -253,7 +273,7 @@
description: |
The number of seconds each task is going to wait by default between retries. Can be overridden at
dag or task level.
- version_added: 2.3.2
+ version_added: 2.4.0
type: integer
example: ~
default: "300"
@@ -374,6 +394,31 @@
example: ~
default: "1024"
+ - name: daemon_umask
+ description: |
+ The default umask to use for processes when run in daemon mode (scheduler, worker, etc.)
+
+ This controls the file-creation mode mask which determines the initial value of file permission bits
+ for newly created files.
+
+ This value is treated as an octal-integer.
+ version_added: 2.3.4
+ type: string
+ default: "0o077"
+ example: ~
+ - name: dataset_manager_class
+ description: Class to use as dataset manager.
+ version_added: 2.4.0
+ type: string
+ default: ~
+ example: 'airflow.datasets.manager.DatasetManager'
+ - name: dataset_manager_kwargs
+ description: Kwargs to supply to dataset manager.
+ version_added: 2.4.0
+ type: string
+ default: ~
+ example: '{"some_param": "some_value"}'
+
- name: database
description: ~
options:
@@ -405,7 +450,8 @@
default: "utf-8"
- name: sql_engine_collation_for_ids
description: |
- Collation for ``dag_id``, ``task_id``, ``key`` columns in case they have different encoding.
+ Collation for ``dag_id``, ``task_id``, ``key``, ``external_executor_id`` columns
+ in case they have different encoding.
By default this collation is the same as the database collation, however for ``mysql`` and ``mariadb``
the default is ``utf8mb3_bin`` so that the index sizes of our index keys will not exceed
the maximum size of allowed index when collation is set to ``utf8mb4`` variant
@@ -459,7 +505,7 @@
Check connection at the start of each connection pool checkout.
Typically, this is a simple statement like "SELECT 1".
More information here:
- https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
+ https://docs.sqlalchemy.org/en/14/core/pooling.html#disconnect-handling-pessimistic
version_added: 2.3.0
type: string
example: ~
@@ -477,7 +523,7 @@
Import path for connect args in SqlAlchemy. Defaults to an empty dict.
This is useful when you want to configure db engine args that SqlAlchemy won't parse
in connection string.
- See https://docs.sqlalchemy.org/en/13/core/engines.html#sqlalchemy.create_engine.params.connect_args
+ See https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args
version_added: 2.3.0
type: string
example: ~
@@ -633,6 +679,27 @@
type: string
example: ~
default: "%%(asctime)s %%(levelname)s - %%(message)s"
+ - name: dag_processor_log_target
+ description: Where to send dag parser logs. If "file",
+ logs are sent to log files defined by child_process_log_directory.
+ version_added: 2.4.0
+ type: string
+ example: ~
+ default: "file"
+ - name: dag_processor_log_format
+ description: |
+ Format of Dag Processor Log line
+ version_added: 2.4.0
+ type: string
+ example: ~
+ default: "[%%(asctime)s] [SOURCE:DAG_PROCESSOR]
+ {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s"
+ - name: log_formatter_class
+ description: ~
+ version_added: 2.3.4
+ type: string
+ example: ~
+ default: "airflow.utils.log.timezone_aware.TimezoneAware"
- name: task_log_prefix_template
description: |
Specify prefix pattern like mentioned below with stream handler TaskHandlerWithCustomFormatter
@@ -1135,7 +1202,9 @@
- name: worker_class
description: |
The worker class gunicorn should use. Choices include
- sync (default), eventlet, gevent
+ sync (default), eventlet, gevent. Note when using gevent you might also want to set the
+ "_AIRFLOW_PATCH_GEVENT" environment variable to "1" to make sure gevent patching is done as
+ early as possible.
version_added: ~
type: string
example: ~
@@ -1165,7 +1234,9 @@
default: ""
- name: expose_config
description: |
- Expose the configuration file in the web server
+ Expose the configuration file in the web server. Set to "non-sensitive-only" to show all values
+ except those that have security implications. "True" shows all values. "False" hides the
+ configuration completely.
version_added: ~
type: string
example: ~
@@ -1183,7 +1254,7 @@
version_added: 1.10.8
type: string
example: ~
- default: "True"
+ default: "False"
- name: dag_default_view
description: |
Default DAG view. Valid values are: ``grid``, ``graph``, ``duration``, ``gantt``, ``landing_times``
@@ -1645,15 +1716,6 @@
type: boolean
example: ~
default: "true"
- - name: worker_umask
- description: |
- Umask that will be used when starting workers with the ``airflow celery worker``
- in daemon mode. This control the file-creation mode mask which determines the initial
- value of file permission bits for newly created files.
- version_added: 2.0.0
- type: string
- example: ~
- default: "0o077"
- name: broker_url
description: |
The Celery broker URL. Celery supports RabbitMQ, Redis and experimentally
@@ -1670,12 +1732,13 @@
or insert it into a database (depending of the backend)
This status is used by the scheduler to update the state of the task
The use of a database is highly recommended
+ When not specified, sql_alchemy_conn with a db+ scheme prefix will be used
http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-result-backend-settings
version_added: ~
type: string
sensitive: true
- example: ~
- default: "db+postgresql://postgres:airflow@postgres/airflow"
+ example: "db+postgresql://postgres:airflow@postgres/airflow"
+ default: ~
- name: flower_host
description: |
Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start
@@ -1905,11 +1968,12 @@
type: string
example: ~
default: "30"
- - name: deactivate_stale_dags_interval
+ - name: parsing_cleanup_interval
description: |
How often (in seconds) to check for stale DAGs (DAGs which are no longer present in
- the expected files) which should be deactivated.
- version_added: 2.2.5
+ the expected files) which should be deactivated, as well as datasets that are no longer
+ referenced and should be marked as orphaned.
+ version_added: 2.5.0
type: integer
example: ~
default: "60"
@@ -1943,6 +2007,22 @@
type: string
example: ~
default: "30"
+ - name: enable_health_check
+ description: |
+ When you start a scheduler, airflow starts a tiny web server
+ subprocess to serve a health check if this is set to True
+ version_added: 2.4.0
+ type: boolean
+ example: ~
+ default: "False"
+ - name: scheduler_health_check_server_port
+ description: |
+ When you start a scheduler, airflow starts a tiny web server
+ subprocess to serve a health check on this port
+ version_added: 2.4.0
+ type: string
+ example: ~
+ default: "8974"
- name: orphaned_tasks_check_interval
description: |
How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs
@@ -2081,6 +2161,14 @@
type: integer
example: ~
default: "20"
+ - name: dag_stale_not_seen_duration
+ description: |
+ Only applicable if `[scheduler]standalone_dag_processor` is true.
+ Time in seconds after which dags that were not updated by the Dag Processor are deactivated.
+ version_added: 2.4.0
+ type: integer
+ example: ~
+ default: "600"
- name: use_job_schedule
description: |
Turn off scheduler use of cron intervals by setting this to False.
@@ -2097,12 +2185,6 @@
type: string
example: ~
default: "False"
- - name: dependency_detector
- description: DAG dependency detector class to use
- version_added: 2.1.0
- type: string
- example: ~
- default: "airflow.serialization.serialized_objects.DependencyDetector"
- name: trigger_timeout_check_interval
description: |
How often to check for expired trigger requests that have not run yet.
@@ -2252,7 +2334,7 @@
type: string
example: ~
default: "True"
-- name: kubernetes
+- name: kubernetes_executor
description: ~
options:
- name: pod_template_file
@@ -2445,36 +2527,3 @@
type: float
example: ~
default: "604800"
-- name: smart_sensor
- description: ~
- options:
- - name: use_smart_sensor
- description: |
- When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to
- smart sensor task.
- version_added: 2.0.0
- type: boolean
- example: ~
- default: "False"
- - name: shard_code_upper_limit
- description: |
- `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated
- by `hashcode % shard_code_upper_limit`.
- version_added: 2.0.0
- type: integer
- example: ~
- default: "10000"
- - name: shards
- description: |
- The number of running smart sensor processes for each service.
- version_added: 2.0.0
- type: integer
- example: ~
- default: "5"
- - name: sensors_enabled
- description: |
- comma separated sensor classes support in smart_sensor.
- version_added: 2.0.0
- type: string
- example: ~
- default: "NamedHivePartitionSensor"
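
Note on daemon_umask in the config.yml changes above: the option is documented as a string that is "treated as an octal-integer", with default "0o077". A minimal sketch of what that means in practice follows. The helper names parse_umask and mode_after_umask are hypothetical, and parsing with int(value, 8) is an assumption about how such a string could be read, not a statement about Airflow's own implementation.

```python
def parse_umask(value: str) -> int:
    """Read a umask string such as "0o077" or "077" as an octal integer."""
    # int() with base 8 accepts both the bare digits and the "0o" prefix.
    return int(value, 8)


def mode_after_umask(requested_mode: int, umask: int) -> int:
    """Permission bits a new file ends up with: requested bits minus the masked ones."""
    return requested_mode & ~umask


umask = parse_umask("0o077")
# A file requested with rw-rw-rw- (0o666) keeps only the owner bits under 0o077.
print(oct(mode_after_umask(0o666, umask)))
```

With the default mask, group and other permission bits are cleared, so newly created files are readable and writable by the owning user only.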
diff --git a/airflow/config_templates/default_airflow.cfg b/airflow/config_templates/default_airflow.cfg
index b5c1d4290bd60..80915d6643a56 100644
--- a/airflow/config_templates/default_airflow.cfg
+++ b/airflow/config_templates/default_airflow.cfg
@@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.
-
# This is the template for Airflow's default configuration. When Airflow is
# imported, it looks for a configuration file at $AIRFLOW_HOME/airflow.cfg. If
# it doesn't exist, Airflow uses this template to generate it by replacing
@@ -36,12 +35,12 @@ dags_folder = {AIRFLOW_HOME}/dags
# Hostname by providing a path to a callable, which will resolve the hostname.
# The format is "package.function".
#
-# For example, default value "socket.getfqdn" means that result from getfqdn() of "socket"
-# package will be used as hostname.
+# For example, the default value "airflow.utils.net.getfqdn" means that the result of a patched
+# version of socket.getfqdn() is used as the hostname - see https://github.com/python/cpython/issues/49254.
#
# No argument should be required in the function specified.
# If using IP address as hostname is preferred, use value ``airflow.utils.net.get_host_ip_address``
-hostname_callable = socket.getfqdn
+hostname_callable = airflow.utils.net.getfqdn
# Default timezone in case supplied date times are naive
# can be utc (default), system, or any IANA timezone string (e.g. Europe/Amsterdam)
@@ -76,6 +75,14 @@ dags_are_paused_at_creation = True
# which is defaulted as ``max_active_runs_per_dag``.
max_active_runs_per_dag = 16
+# The name of the method used in order to start Python processes via the multiprocessing module.
+# This corresponds directly with the options available in the Python docs:
+# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method.
+# Must be one of the values returned by:
+# https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_all_start_methods.
+# Example: mp_start_method = fork
+# mp_start_method =
+
# Whether to load the DAG examples that ship with Airflow. It's good to
# get started, but you probably want to set this to ``False`` in a production
# environment
@@ -128,6 +135,11 @@ unit_test_mode = False
# RCE exploits).
enable_xcom_pickling = False
+# What classes can be imported during deserialization. This is a multi line value.
+# The individual items will be parsed as regexp. Python built-in classes (like dict)
+# are always allowed
+allowed_deserialization_classes = airflow\..*
+
# When a task is killed forcefully, this is the amount of time in seconds that
# it has to cleanup after it is sent a SIGTERM, before it is SIGKILLED
killed_task_cleanup_time = 60
@@ -212,6 +224,22 @@ default_pool_task_slot_count = 128
# mapped tasks from clogging the scheduler.
max_map_length = 1024
+# The default umask to use for processes when run in daemon mode (scheduler, worker, etc.)
+#
+# This controls the file-creation mode mask which determines the initial value of file permission bits
+# for newly created files.
+#
+# This value is treated as an octal-integer.
+daemon_umask = 0o077
+
+# Class to use as dataset manager.
+# Example: dataset_manager_class = airflow.datasets.manager.DatasetManager
+# dataset_manager_class =
+
+# Kwargs to supply to dataset manager.
+# Example: dataset_manager_kwargs = {{"some_param": "some_value"}}
+# dataset_manager_kwargs =
+
[database]
# The SqlAlchemy connection string to the metadata database.
# SqlAlchemy supports many different database engines.
@@ -226,7 +254,8 @@ sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/airflow.db
# The encoding for the databases
sql_engine_encoding = utf-8
-# Collation for ``dag_id``, ``task_id``, ``key`` columns in case they have different encoding.
+# Collation for ``dag_id``, ``task_id``, ``key``, ``external_executor_id`` columns
+# in case they have different encoding.
# By default this collation is the same as the database collation, however for ``mysql`` and ``mariadb``
# the default is ``utf8mb3_bin`` so that the index sizes of our index keys will not exceed
# the maximum size of allowed index when collation is set to ``utf8mb4`` variant
@@ -260,7 +289,7 @@ sql_alchemy_pool_recycle = 1800
# Check connection at the start of each connection pool checkout.
# Typically, this is a simple statement like "SELECT 1".
# More information here:
-# https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic
+# https://docs.sqlalchemy.org/en/14/core/pooling.html#disconnect-handling-pessimistic
sql_alchemy_pool_pre_ping = True
# The schema to use for the metadata database.
@@ -270,7 +299,7 @@ sql_alchemy_schema =
# Import path for connect args in SqlAlchemy. Defaults to an empty dict.
# This is useful when you want to configure db engine args that SqlAlchemy won't parse
# in connection string.
-# See https://docs.sqlalchemy.org/en/13/core/engines.html#sqlalchemy.create_engine.params.connect_args
+# See https://docs.sqlalchemy.org/en/14/core/engines.html#sqlalchemy.create_engine.params.connect_args
# sql_alchemy_connect_args =
# Whether to load the default connections that ship with Airflow. It's good to
@@ -350,6 +379,13 @@ colored_formatter_class = airflow.utils.log.colored_log.CustomTTYColoredFormatte
log_format = [%%(asctime)s] {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s
simple_log_format = %%(asctime)s %%(levelname)s - %%(message)s
+# Where to send dag parser logs. If "file", logs are sent to log files defined by child_process_log_directory.
+dag_processor_log_target = file
+
+# Format of Dag Processor Log line
+dag_processor_log_format = [%%(asctime)s] [SOURCE:DAG_PROCESSOR] {{%%(filename)s:%%(lineno)d}} %%(levelname)s - %%(message)s
+log_formatter_class = airflow.utils.log.timezone_aware.TimezoneAware
+
# Specify prefix pattern like mentioned below with stream handler TaskHandlerWithCustomFormatter
# Example: task_log_prefix_template = {{ti.dag_id}}-{{ti.task_id}}-{{execution_date}}-{{try_number}}
task_log_prefix_template =
@@ -585,7 +621,9 @@ secret_key = {SECRET_KEY}
workers = 4
# The worker class gunicorn should use. Choices include
-# sync (default), eventlet, gevent
+# sync (default), eventlet, gevent. Note when using gevent you might also want to set the
+# "_AIRFLOW_PATCH_GEVENT" environment variable to "1" to make sure gevent patching is done as
+# early as possible.
worker_class = sync
# Log files for the gunicorn webserver. '-' means log to stderr.
@@ -599,14 +637,16 @@ error_logfile = -
# documentation - https://docs.gunicorn.org/en/stable/settings.html#access-log-format
access_logformat =
-# Expose the configuration file in the web server
+# Expose the configuration file in the web server. Set to "non-sensitive-only" to show all values
+# except those that have security implications. "True" shows all values. "False" hides the
+# configuration completely.
expose_config = False
# Expose hostname in the web server
expose_hostname = True
# Expose stacktrace in the web server
-expose_stacktrace = True
+expose_stacktrace = False
# Default DAG view. Valid values are: ``grid``, ``graph``, ``duration``, ``gantt``, ``landing_times``
dag_default_view = grid
@@ -832,11 +872,6 @@ worker_prefetch_multiplier = 1
# prevent this by setting this to false. However, with this disabled Flower won't work.
worker_enable_remote_control = true
-# Umask that will be used when starting workers with the ``airflow celery worker``
-# in daemon mode. This control the file-creation mode mask which determines the initial
-# value of file permission bits for newly created files.
-worker_umask = 0o077
-
# The Celery broker URL. Celery supports RabbitMQ, Redis and experimentally
# a sqlalchemy database. Refer to the Celery documentation for more information.
broker_url = redis://redis:6379/0
@@ -846,8 +881,10 @@ broker_url = redis://redis:6379/0
# or insert it into a database (depending of the backend)
# This status is used by the scheduler to update the state of the task
# The use of a database is highly recommended
+# When not specified, sql_alchemy_conn with a db+ scheme prefix will be used
# http://docs.celeryproject.org/en/latest/userguide/configuration.html#task-result-backend-settings
-result_backend = db+postgresql://postgres:airflow@postgres/airflow
+# Example: result_backend = db+postgresql://postgres:airflow@postgres/airflow
+# result_backend =
# Celery Flower is a sweet UI for Celery. Airflow has a shortcut to start
# it ``airflow celery flower``. This defines the IP that Celery Flower runs on
@@ -963,8 +1000,9 @@ scheduler_idle_sleep_time = 1
min_file_process_interval = 30
# How often (in seconds) to check for stale DAGs (DAGs which are no longer present in
-# the expected files) which should be deactivated.
-deactivate_stale_dags_interval = 60
+# the expected files) which should be deactivated, as well as datasets that are no longer
+# referenced and should be marked as orphaned.
+parsing_cleanup_interval = 60
# How often (in seconds) to scan the DAGs directory for new files. Default to 5 minutes.
dag_dir_list_interval = 300
@@ -980,6 +1018,14 @@ pool_metrics_interval = 5.0
# This is used by the health check in the "/health" endpoint
scheduler_health_check_threshold = 30
+# When you start a scheduler, airflow starts a tiny web server
+# subprocess to serve a health check if this is set to True
+enable_health_check = False
+
+# When you start a scheduler, airflow starts a tiny web server
+# subprocess to serve a health check on this port
+scheduler_health_check_server_port = 8974
+
# How often (in seconds) should the scheduler check for orphaned tasks and SchedulerJobs
orphaned_tasks_check_interval = 300.0
child_process_log_directory = {AIRFLOW_HOME}/logs/scheduler
@@ -1054,6 +1100,10 @@ standalone_dag_processor = False
# in database. Contains maximum number of callbacks that are fetched during a single loop.
max_callbacks_per_loop = 20
+# Only applicable if `[scheduler]standalone_dag_processor` is true.
+# Time in seconds after which dags that were not updated by the Dag Processor are deactivated.
+dag_stale_not_seen_duration = 600
+
# Turn off scheduler use of cron intervals by setting this to False.
# DAGs submitted manually in the web UI or with trigger_dag will still run.
use_job_schedule = True
@@ -1062,9 +1112,6 @@ use_job_schedule = True
# Only has effect if schedule_interval is set to None in DAG
allow_trigger_in_future = False
-# DAG dependency detector class to use
-dependency_detector = airflow.serialization.serialized_objects.DependencyDetector
-
# How often to check for expired trigger requests that have not run yet.
trigger_timeout_check_interval = 15
@@ -1122,7 +1169,7 @@ offset_field = offset
use_ssl = False
verify_certs = True
-[kubernetes]
+[kubernetes_executor]
# Path to the YAML pod file that forms the basis for KubernetesExecutor workers.
pod_template_file =
@@ -1218,18 +1265,3 @@ worker_pods_pending_timeout_batch_size = 100
[sensors]
# Sensor default timeout, 7 days by default (7 * 24 * 60 * 60).
default_timeout = 604800
-
-[smart_sensor]
-# When `use_smart_sensor` is True, Airflow redirects multiple qualified sensor tasks to
-# smart sensor task.
-use_smart_sensor = False
-
-# `shard_code_upper_limit` is the upper limit of `shard_code` value. The `shard_code` is generated
-# by `hashcode % shard_code_upper_limit`.
-shard_code_upper_limit = 10000
-
-# The number of running smart sensor processes for each service.
-shards = 5
-
-# comma separated sensor classes support in smart_sensor.
-sensors_enabled = NamedHivePartitionSensor
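
Note on the default_airflow.cfg changes above: several options and sections are renamed or dropped from the template, so an existing airflow.cfg may still carry the old names. The sketch below scans a config file's text for the names this diff touches. It is a hypothetical helper (CHANGED and find_outdated_settings are not part of Airflow), the list is limited to what is visible in this diff, and the notes only describe what the diff shows rather than any deprecation behaviour at runtime.

```python
from __future__ import annotations

import configparser

# (section, option) -> what the diff shows. option=None means the whole section.
CHANGED = {
    ("scheduler", "deactivate_stale_dags_interval"): "renamed to parsing_cleanup_interval",
    ("celery", "worker_umask"): "removed; [core] daemon_umask is added in the same change",
    ("scheduler", "dependency_detector"): "removed from the template",
    ("kubernetes", None): "section renamed to [kubernetes_executor]",
    ("smart_sensor", None): "section removed from the template",
}


def find_outdated_settings(cfg_text: str) -> list[str]:
    """Return one message per old section/option still present in cfg_text."""
    # interpolation=None: airflow.cfg values contain %% and {{...}} sequences
    # that should be read literally here.
    parser = configparser.ConfigParser(interpolation=None)
    parser.read_string(cfg_text)
    findings = []
    for (section, option), note in CHANGED.items():
        if not parser.has_section(section):
            continue
        if option is None:
            findings.append(f"[{section}]: {note}")
        elif parser.has_option(section, option):
            findings.append(f"[{section}] {option}: {note}")
    return findings


sample = """
[scheduler]
deactivate_stale_dags_interval = 60

[kubernetes]
pod_template_file =
"""
for line in find_outdated_settings(sample):
    print(line)
```

A check like this only reports names; deciding on new values (for example whether to set result_backend explicitly now that its default is empty) still needs a manual review of the diff.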
diff --git a/airflow/config_templates/default_celery.py b/airflow/config_templates/default_celery.py
index 9d81c6353fba2..d3d5a4adf11e2 100644
--- a/airflow/config_templates/default_celery.py
+++ b/airflow/config_templates/default_celery.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Default celery configuration."""
+from __future__ import annotations
+
import logging
import ssl
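The next hunk replaces a direct read of `[celery] result_backend` with a fallback to the metadata database URL prefixed with `db+`. A minimal sketch of that lookup order, using the standard-library `ConfigParser` in place of Airflow's `conf` object; `resolve_result_backend` and the sample connection strings are hypothetical names used only for illustration:

```python
from configparser import ConfigParser


def resolve_result_backend(parser: ConfigParser) -> str:
    """Prefer [celery] result_backend; otherwise derive it from the DB URL."""
    if parser.has_option("celery", "result_backend"):
        return parser.get("celery", "result_backend")
    # Fallback mirrors the diff: prefix the SQLAlchemy URL with "db+".
    return f'db+{parser.get("database", "sql_alchemy_conn")}'


parser = ConfigParser()
parser.read_string(
    """
[database]
sql_alchemy_conn = postgresql://airflow:airflow@localhost/airflow

[celery]
broker_url = redis://localhost:6379/0
"""
)

# No explicit result_backend yet, so the database URL is reused.
fallback = resolve_result_backend(parser)

# Once the option is set, it takes precedence over the fallback.
parser.set("celery", "result_backend", "db+mysql://airflow:airflow@localhost:3306/airflow")
explicit = resolve_result_backend(parser)
```

The real module reads through `conf.has_option` and `conf.get_mandatory_value`, which also honour environment variables and secrets backends; the sketch only covers the file-based case.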
@@ -29,71 +31,76 @@ def _broker_supports_visibility_timeout(url):
log = logging.getLogger(__name__)
-broker_url = conf.get('celery', 'BROKER_URL')
+broker_url = conf.get("celery", "BROKER_URL")
-broker_transport_options = conf.getsection('celery_broker_transport_options') or {}
-if 'visibility_timeout' not in broker_transport_options:
+broker_transport_options = conf.getsection("celery_broker_transport_options") or {}
+if "visibility_timeout" not in broker_transport_options:
if _broker_supports_visibility_timeout(broker_url):
- broker_transport_options['visibility_timeout'] = 21600
+ broker_transport_options["visibility_timeout"] = 21600
+
+if conf.has_option("celery", "RESULT_BACKEND"):
+ result_backend = conf.get_mandatory_value("celery", "RESULT_BACKEND")
+else:
+ log.debug("Value for celery result_backend not found. Using sql_alchemy_conn with db+ prefix.")
+ result_backend = f'db+{conf.get("database", "SQL_ALCHEMY_CONN")}'
DEFAULT_CELERY_CONFIG = {
- 'accept_content': ['json'],
- 'event_serializer': 'json',
- 'worker_prefetch_multiplier': conf.getint('celery', 'worker_prefetch_multiplier'),
- 'task_acks_late': True,
- 'task_default_queue': conf.get('operators', 'DEFAULT_QUEUE'),
- 'task_default_exchange': conf.get('operators', 'DEFAULT_QUEUE'),
- 'task_track_started': conf.getboolean('celery', 'task_track_started'),
- 'broker_url': broker_url,
- 'broker_transport_options': broker_transport_options,
- 'result_backend': conf.get('celery', 'RESULT_BACKEND'),
- 'worker_concurrency': conf.getint('celery', 'WORKER_CONCURRENCY'),
- 'worker_enable_remote_control': conf.getboolean('celery', 'worker_enable_remote_control'),
+ "accept_content": ["json"],
+ "event_serializer": "json",
+ "worker_prefetch_multiplier": conf.getint("celery", "worker_prefetch_multiplier"),
+ "task_acks_late": True,
+ "task_default_queue": conf.get("operators", "DEFAULT_QUEUE"),
+ "task_default_exchange": conf.get("operators", "DEFAULT_QUEUE"),
+ "task_track_started": conf.getboolean("celery", "task_track_started"),
+ "broker_url": broker_url,
+ "broker_transport_options": broker_transport_options,
+ "result_backend": result_backend,
+ "worker_concurrency": conf.getint("celery", "WORKER_CONCURRENCY"),
+ "worker_enable_remote_control": conf.getboolean("celery", "worker_enable_remote_control"),
}
celery_ssl_active = False
try:
- celery_ssl_active = conf.getboolean('celery', 'SSL_ACTIVE')
+ celery_ssl_active = conf.getboolean("celery", "SSL_ACTIVE")
except AirflowConfigException:
log.warning("Celery Executor will run without SSL")
try:
if celery_ssl_active:
- if broker_url and 'amqp://' in broker_url:
+ if broker_url and "amqp://" in broker_url:
broker_use_ssl = {
- 'keyfile': conf.get('celery', 'SSL_KEY'),
- 'certfile': conf.get('celery', 'SSL_CERT'),
- 'ca_certs': conf.get('celery', 'SSL_CACERT'),
- 'cert_reqs': ssl.CERT_REQUIRED,
+ "keyfile": conf.get("celery", "SSL_KEY"),
+ "certfile": conf.get("celery", "SSL_CERT"),
+ "ca_certs": conf.get("celery", "SSL_CACERT"),
+ "cert_reqs": ssl.CERT_REQUIRED,
}
- elif broker_url and 'redis://' in broker_url:
+ elif broker_url and "redis://" in broker_url:
broker_use_ssl = {
- 'ssl_keyfile': conf.get('celery', 'SSL_KEY'),
- 'ssl_certfile': conf.get('celery', 'SSL_CERT'),
- 'ssl_ca_certs': conf.get('celery', 'SSL_CACERT'),
- 'ssl_cert_reqs': ssl.CERT_REQUIRED,
+ "ssl_keyfile": conf.get("celery", "SSL_KEY"),
+ "ssl_certfile": conf.get("celery", "SSL_CERT"),
+ "ssl_ca_certs": conf.get("celery", "SSL_CACERT"),
+ "ssl_cert_reqs": ssl.CERT_REQUIRED,
}
else:
raise AirflowException(
- 'The broker you configured does not support SSL_ACTIVE to be True. '
- 'Please use RabbitMQ or Redis if you would like to use SSL for broker.'
+ "The broker you configured does not support SSL_ACTIVE to be True. "
+ "Please use RabbitMQ or Redis if you would like to use SSL for broker."
)
- DEFAULT_CELERY_CONFIG['broker_use_ssl'] = broker_use_ssl
+ DEFAULT_CELERY_CONFIG["broker_use_ssl"] = broker_use_ssl
except AirflowConfigException:
raise AirflowException(
- 'AirflowConfigException: SSL_ACTIVE is True, '
- 'please ensure SSL_KEY, '
- 'SSL_CERT and SSL_CACERT are set'
+ "AirflowConfigException: SSL_ACTIVE is True, "
+ "please ensure SSL_KEY, "
+ "SSL_CERT and SSL_CACERT are set"
)
except Exception as e:
raise AirflowException(
- f'Exception: There was an unknown Celery SSL Error. Please ensure you want to use SSL and/or have '
- f'all necessary certs and key ({e}).'
+ f"Exception: There was an unknown Celery SSL Error. Please ensure you want to use SSL and/or have "
+ f"all necessary certs and key ({e})."
)
-result_backend = str(DEFAULT_CELERY_CONFIG['result_backend'])
-if 'amqp://' in result_backend or 'redis://' in result_backend or 'rpc://' in result_backend:
+if "amqp://" in result_backend or "redis://" in result_backend or "rpc://" in result_backend:
log.warning(
"You have configured a result_backend of %s, it is highly recommended "
"to use an alternative result_backend (i.e. a database).",
diff --git a/airflow/config_templates/default_test.cfg b/airflow/config_templates/default_test.cfg
index 2f9b6fa264b13..523f52cb69a04 100644
--- a/airflow/config_templates/default_test.cfg
+++ b/airflow/config_templates/default_test.cfg
@@ -32,100 +32,37 @@
unit_test_mode = True
dags_folder = {TEST_DAGS_FOLDER}
plugins_folder = {TEST_PLUGINS_FOLDER}
-executor = SequentialExecutor
-load_examples = True
-donot_pickle = True
-max_active_tasks_per_dag = 16
dags_are_paused_at_creation = False
fernet_key = {FERNET_KEY}
-enable_xcom_pickling = False
killed_task_cleanup_time = 5
-hostname_callable = socket.getfqdn
-default_task_retries = 0
-# This is a hack, too many tests assume DAGs are already in the DB. We need to fix those tests instead
-store_serialized_dags = False
+allowed_deserialization_classes = airflow\..*
+ tests\..*
[database]
sql_alchemy_conn = sqlite:///{AIRFLOW_HOME}/unittests.db
-load_default_connections = True
[logging]
-base_log_folder = {AIRFLOW_HOME}/logs
-logging_level = INFO
celery_logging_level = WARN
-fab_logging_level = WARN
-log_filename_template = {{{{ ti.dag_id }}}}/{{{{ ti.task_id }}}}/{{{{ ts }}}}/{{{{ try_number }}}}.log
-log_processor_filename_template = {{{{ filename }}}}.log
-dag_processor_manager_log_location = {AIRFLOW_HOME}/logs/dag_processor_manager/dag_processor_manager.log
-worker_log_server_port = 8793
-
-[cli]
-api_client = airflow.api.client.local_client
-endpoint_url = http://localhost:8080
[api]
auth_backends = airflow.api.auth.backend.default
-[operators]
-default_owner = airflow
-
-
[hive]
default_hive_mapred_queue = airflow
-[webserver]
-base_url = http://localhost:8080
-web_server_host = 0.0.0.0
-web_server_port = 8080
-dag_orientation = LR
-dag_default_view = tree
-log_fetch_timeout_sec = 5
-hide_paused_dags_by_default = False
-page_size = 100
-
-[email]
-email_backend = airflow.utils.email.send_email_smtp
-email_conn_id = smtp_default
-
[smtp]
-smtp_host = localhost
smtp_user = airflow
-smtp_port = 25
smtp_password = airflow
-smtp_mail_from = airflow@example.com
-smtp_retry_limit = 5
-smtp_timeout = 30
[celery]
-celery_app_name = airflow.executors.celery_executor
-worker_concurrency = 16
broker_url = sqla+mysql://airflow:airflow@localhost:3306/airflow
result_backend = db+mysql://airflow:airflow@localhost:3306/airflow
-flower_host = 0.0.0.0
-flower_port = 5555
-sync_parallelism = 0
-worker_precheck = False
[scheduler]
job_heartbeat_sec = 1
schedule_after_task_execution = False
-scheduler_heartbeat_sec = 5
-scheduler_health_check_threshold = 30
-parsing_processes = 2
-catchup_by_default = True
-scheduler_zombie_task_threshold = 300
+scheduler_health_check_server_port = 8794
dag_dir_list_interval = 0
-max_tis_per_query = 512
[elasticsearch]
-host =
-log_id_template = {{dag_id}}-{{task_id}}-{{execution_date}}-{{try_number}}
-end_of_log_mark = end_of_log
-
-[elasticsearch_configs]
-
-use_ssl = False
-verify_certs = True
-
-[kubernetes]
-dags_volume_claim = default
+log_id_template = {{dag_id}}-{{task_id}}-{{run_id}}-{{map_index}}-{{try_number}}
diff --git a/airflow/config_templates/default_webserver_config.py b/airflow/config_templates/default_webserver_config.py
index 5c7eaa1b16a39..aa22b125fa98c 100644
--- a/airflow/config_templates/default_webserver_config.py
+++ b/airflow/config_templates/default_webserver_config.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Default configuration for the Airflow webserver"""
+"""Default configuration for the Airflow webserver."""
+from __future__ import annotations
+
import os
from airflow.www.fab_security.manager import AUTH_DB
@@ -30,6 +32,7 @@
# Flask-WTF flag for CSRF
WTF_CSRF_ENABLED = True
+WTF_CSRF_TIME_LIMIT = None
# ----------------------------------------------------
# AUTHENTICATION CONFIG
diff --git a/airflow/configuration.py b/airflow/configuration.py
index 93612160c8290..ce55aa45c6075 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import datetime
import functools
import json
@@ -31,14 +33,15 @@
# Ignored Mypy on configparser because it thinks the configparser module has no _UNSET attribute
from configparser import _UNSET, ConfigParser, NoOptionError, NoSectionError # type: ignore
-from contextlib import suppress
+from contextlib import contextmanager, suppress
from json.decoder import JSONDecodeError
from re import Pattern
-from typing import IO, Any, Dict, Iterable, List, Optional, Set, Tuple, Union
-from urllib.parse import urlparse
+from typing import IO, Any, Dict, Iterable, Tuple, Union
+from urllib.parse import urlsplit
from typing_extensions import overload
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowConfigException
from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH, BaseSecretsBackend
from airflow.utils import yaml
@@ -49,8 +52,8 @@
# show Airflow's deprecation warnings
if not sys.warnoptions:
- warnings.filterwarnings(action='default', category=DeprecationWarning, module='airflow')
- warnings.filterwarnings(action='default', category=PendingDeprecationWarning, module='airflow')
+ warnings.filterwarnings(action="default", category=DeprecationWarning, module="airflow")
+ warnings.filterwarnings(action="default", category=PendingDeprecationWarning, module="airflow")
_SQLITE3_VERSION_PATTERN = re.compile(r"(?P<version>^\d+(?:\.\d+)*)\D?.*$")
@@ -59,10 +62,10 @@
ConfigSectionSourcesType = Dict[str, Union[str, Tuple[str, str]]]
ConfigSourcesType = Dict[str, ConfigSectionSourcesType]
-ENV_VAR_PREFIX = 'AIRFLOW__'
+ENV_VAR_PREFIX = "AIRFLOW__"
-def _parse_sqlite_version(s: str) -> Tuple[int, ...]:
+def _parse_sqlite_version(s: str) -> tuple[int, ...]:
match = _SQLITE3_VERSION_PATTERN.match(s)
if match is None:
return ()
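The hunk above shows only the start of `_parse_sqlite_version`. A minimal sketch of how such a parser could work, assuming the pattern exposes a named group `version` and that the hidden remainder of the function splits that group on dots; `parse_version` is a stand-in name, not the function from the diff:

```python
from __future__ import annotations

import re

# The named group "version" captures the leading dotted-integer prefix.
_VERSION_PATTERN = re.compile(r"(?P<version>^\d+(?:\.\d+)*)\D?.*$")


def parse_version(s: str) -> tuple[int, ...]:
    """Return the numeric prefix of a version string as a tuple of ints."""
    match = _VERSION_PATTERN.match(s)
    if match is None:
        return ()
    # Assumed continuation: split the captured prefix on dots.
    return tuple(int(part) for part in match.group("version").split("."))


parsed = parse_version("3.15.0")
with_suffix = parse_version("3.39.4-beta")
unparsable = parse_version("unknown")
```

Returning tuples lets callers compare versions with ordinary tuple ordering, for example `parse_version("3.15.0") >= (3, 15, 0)`.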
@@ -79,11 +82,12 @@ def expand_env_var(env_var: str) -> str:
...
-def expand_env_var(env_var: Union[str, None]) -> Optional[Union[str, None]]:
+def expand_env_var(env_var: str | None) -> str | None:
"""
- Expands (potentially nested) env vars by repeatedly applying
- `expandvars` and `expanduser` until interpolation stops having
- any effect.
+ Expands (potentially nested) env vars.
+
+ Repeat and apply `expandvars` and `expanduser` until
+ interpolation stops having any effect.
"""
if not env_var:
return env_var
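The docstring describes expanding until interpolation stops having any effect, which is a fixed-point loop. A minimal sketch of that idea, assuming the loop simply re-applies `os.path.expandvars` and `os.path.expanduser`; `expand_nested` and the `DEMO_*` variables are hypothetical:

```python
from __future__ import annotations

import os


def expand_nested(value: str | None) -> str | None:
    """Expand env vars and '~' repeatedly until the string stops changing."""
    if not value:
        return value
    while True:
        expanded = os.path.expanduser(os.path.expandvars(str(value)))
        if expanded == value:
            return expanded
        value = expanded


# DEMO_LOGS refers to DEMO_BASE, so a single expandvars pass is not enough.
os.environ["DEMO_BASE"] = "/opt/airflow"
os.environ["DEMO_LOGS"] = "$DEMO_BASE/logs"
resolved = expand_nested("${DEMO_LOGS}/scheduler")
```

A variable that refers to itself would keep this loop running, so the sketch assumes the references eventually bottom out.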
@@ -96,11 +100,11 @@ def expand_env_var(env_var: Union[str, None]) -> Optional[Union[str, None]]:
def run_command(command: str) -> str:
- """Runs command and returns stdout"""
+ """Runs command and returns stdout."""
process = subprocess.Popen(
shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.PIPE, close_fds=True
)
- output, stderr = (stream.decode(sys.getdefaultencoding(), 'ignore') for stream in process.communicate())
+ output, stderr = (stream.decode(sys.getdefaultencoding(), "ignore") for stream in process.communicate())
if process.returncode != 0:
raise AirflowConfigException(
@@ -111,8 +115,8 @@ def run_command(command: str) -> str:
return output
-def _get_config_value_from_secret_backend(config_key: str) -> Optional[str]:
- """Get Config option values from Secret Backend"""
+def _get_config_value_from_secret_backend(config_key: str) -> str | None:
+ """Get Config option values from Secret Backend."""
try:
secrets_client = get_custom_secret_backend()
if not secrets_client:
@@ -120,163 +124,183 @@ def _get_config_value_from_secret_backend(config_key: str) -> Optional[str]:
return secrets_client.get_config(config_key)
except Exception as e:
raise AirflowConfigException(
- 'Cannot retrieve config from alternative secrets backend. '
- 'Make sure it is configured properly and that the Backend '
- 'is accessible.\n'
- f'{e}'
+ "Cannot retrieve config from alternative secrets backend. "
+ "Make sure it is configured properly and that the Backend "
+ "is accessible.\n"
+ f"{e}"
)
def _default_config_file_path(file_name: str) -> str:
- templates_dir = os.path.join(os.path.dirname(__file__), 'config_templates')
+ templates_dir = os.path.join(os.path.dirname(__file__), "config_templates")
return os.path.join(templates_dir, file_name)
-def default_config_yaml() -> List[Dict[str, Any]]:
+def default_config_yaml() -> list[dict[str, Any]]:
"""
- Read Airflow configs from YAML file
+ Read Airflow configs from YAML file.
:return: Python dictionary containing configs & their info
"""
- with open(_default_config_file_path('config.yml')) as config_file:
+ with open(_default_config_file_path("config.yml")) as config_file:
return yaml.safe_load(config_file)
+SENSITIVE_CONFIG_VALUES = {
+ ("database", "sql_alchemy_conn"),
+ ("core", "fernet_key"),
+ ("celery", "broker_url"),
+ ("celery", "flower_basic_auth"),
+ ("celery", "result_backend"),
+ ("atlas", "password"),
+ ("smtp", "smtp_password"),
+ ("webserver", "secret_key"),
+ # The following options are deprecated
+ ("core", "sql_alchemy_conn"),
+}
+
+
class AirflowConfigParser(ConfigParser):
- """Custom Airflow Configparser supporting defaults and deprecated options"""
+ """Custom Airflow Configparser supporting defaults and deprecated options."""
# These configuration elements can be fetched as the stdout of commands
# following the "{section}__{name}_cmd" pattern, the idea behind this
# is to not store password on boxes in text files.
# These configs can also be fetched from Secrets backend
# following the "{section}__{name}__secret" pattern
- sensitive_config_values: Set[Tuple[str, str]] = {
- ('database', 'sql_alchemy_conn'),
- ('core', 'fernet_key'),
- ('celery', 'broker_url'),
- ('celery', 'flower_basic_auth'),
- ('celery', 'result_backend'),
- ('atlas', 'password'),
- ('smtp', 'smtp_password'),
- ('webserver', 'secret_key'),
- # The following options are deprecated
- ('core', 'sql_alchemy_conn'),
- }
+
+ sensitive_config_values: set[tuple[str, str]] = SENSITIVE_CONFIG_VALUES
# A mapping of (new section, new option) -> (old section, old option, since_version).
# When reading new option, the old option will be checked to see if it exists. If it does a
# DeprecationWarning will be issued and the old option will be used instead
- deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]] = {
- ('celery', 'worker_precheck'): ('core', 'worker_precheck', '2.0.0'),
- ('logging', 'base_log_folder'): ('core', 'base_log_folder', '2.0.0'),
- ('logging', 'remote_logging'): ('core', 'remote_logging', '2.0.0'),
- ('logging', 'remote_log_conn_id'): ('core', 'remote_log_conn_id', '2.0.0'),
- ('logging', 'remote_base_log_folder'): ('core', 'remote_base_log_folder', '2.0.0'),
- ('logging', 'encrypt_s3_logs'): ('core', 'encrypt_s3_logs', '2.0.0'),
- ('logging', 'logging_level'): ('core', 'logging_level', '2.0.0'),
- ('logging', 'fab_logging_level'): ('core', 'fab_logging_level', '2.0.0'),
- ('logging', 'logging_config_class'): ('core', 'logging_config_class', '2.0.0'),
- ('logging', 'colored_console_log'): ('core', 'colored_console_log', '2.0.0'),
- ('logging', 'colored_log_format'): ('core', 'colored_log_format', '2.0.0'),
- ('logging', 'colored_formatter_class'): ('core', 'colored_formatter_class', '2.0.0'),
- ('logging', 'log_format'): ('core', 'log_format', '2.0.0'),
- ('logging', 'simple_log_format'): ('core', 'simple_log_format', '2.0.0'),
- ('logging', 'task_log_prefix_template'): ('core', 'task_log_prefix_template', '2.0.0'),
- ('logging', 'log_filename_template'): ('core', 'log_filename_template', '2.0.0'),
- ('logging', 'log_processor_filename_template'): ('core', 'log_processor_filename_template', '2.0.0'),
- ('logging', 'dag_processor_manager_log_location'): (
- 'core',
- 'dag_processor_manager_log_location',
- '2.0.0',
+ deprecated_options: dict[tuple[str, str], tuple[str, str, str]] = {
+ ("celery", "worker_precheck"): ("core", "worker_precheck", "2.0.0"),
+ ("logging", "base_log_folder"): ("core", "base_log_folder", "2.0.0"),
+ ("logging", "remote_logging"): ("core", "remote_logging", "2.0.0"),
+ ("logging", "remote_log_conn_id"): ("core", "remote_log_conn_id", "2.0.0"),
+ ("logging", "remote_base_log_folder"): ("core", "remote_base_log_folder", "2.0.0"),
+ ("logging", "encrypt_s3_logs"): ("core", "encrypt_s3_logs", "2.0.0"),
+ ("logging", "logging_level"): ("core", "logging_level", "2.0.0"),
+ ("logging", "fab_logging_level"): ("core", "fab_logging_level", "2.0.0"),
+ ("logging", "logging_config_class"): ("core", "logging_config_class", "2.0.0"),
+ ("logging", "colored_console_log"): ("core", "colored_console_log", "2.0.0"),
+ ("logging", "colored_log_format"): ("core", "colored_log_format", "2.0.0"),
+ ("logging", "colored_formatter_class"): ("core", "colored_formatter_class", "2.0.0"),
+ ("logging", "log_format"): ("core", "log_format", "2.0.0"),
+ ("logging", "simple_log_format"): ("core", "simple_log_format", "2.0.0"),
+ ("logging", "task_log_prefix_template"): ("core", "task_log_prefix_template", "2.0.0"),
+ ("logging", "log_filename_template"): ("core", "log_filename_template", "2.0.0"),
+ ("logging", "log_processor_filename_template"): ("core", "log_processor_filename_template", "2.0.0"),
+ ("logging", "dag_processor_manager_log_location"): (
+ "core",
+ "dag_processor_manager_log_location",
+ "2.0.0",
),
- ('logging', 'task_log_reader'): ('core', 'task_log_reader', '2.0.0'),
- ('metrics', 'statsd_on'): ('scheduler', 'statsd_on', '2.0.0'),
- ('metrics', 'statsd_host'): ('scheduler', 'statsd_host', '2.0.0'),
- ('metrics', 'statsd_port'): ('scheduler', 'statsd_port', '2.0.0'),
- ('metrics', 'statsd_prefix'): ('scheduler', 'statsd_prefix', '2.0.0'),
- ('metrics', 'statsd_allow_list'): ('scheduler', 'statsd_allow_list', '2.0.0'),
- ('metrics', 'stat_name_handler'): ('scheduler', 'stat_name_handler', '2.0.0'),
- ('metrics', 'statsd_datadog_enabled'): ('scheduler', 'statsd_datadog_enabled', '2.0.0'),
- ('metrics', 'statsd_datadog_tags'): ('scheduler', 'statsd_datadog_tags', '2.0.0'),
- ('metrics', 'statsd_custom_client_path'): ('scheduler', 'statsd_custom_client_path', '2.0.0'),
- ('scheduler', 'parsing_processes'): ('scheduler', 'max_threads', '1.10.14'),
- ('scheduler', 'scheduler_idle_sleep_time'): ('scheduler', 'processor_poll_interval', '2.2.0'),
- ('operators', 'default_queue'): ('celery', 'default_queue', '2.1.0'),
- ('core', 'hide_sensitive_var_conn_fields'): ('admin', 'hide_sensitive_variable_fields', '2.1.0'),
- ('core', 'sensitive_var_conn_names'): ('admin', 'sensitive_variable_fields', '2.1.0'),
- ('core', 'default_pool_task_slot_count'): ('core', 'non_pooled_task_slot_count', '1.10.4'),
- ('core', 'max_active_tasks_per_dag'): ('core', 'dag_concurrency', '2.2.0'),
- ('logging', 'worker_log_server_port'): ('celery', 'worker_log_server_port', '2.2.0'),
- ('api', 'access_control_allow_origins'): ('api', 'access_control_allow_origin', '2.2.0'),
- ('api', 'auth_backends'): ('api', 'auth_backend', '2.3.0'),
- ('database', 'sql_alchemy_conn'): ('core', 'sql_alchemy_conn', '2.3.0'),
- ('database', 'sql_engine_encoding'): ('core', 'sql_engine_encoding', '2.3.0'),
- ('database', 'sql_engine_collation_for_ids'): ('core', 'sql_engine_collation_for_ids', '2.3.0'),
- ('database', 'sql_alchemy_pool_enabled'): ('core', 'sql_alchemy_pool_enabled', '2.3.0'),
- ('database', 'sql_alchemy_pool_size'): ('core', 'sql_alchemy_pool_size', '2.3.0'),
- ('database', 'sql_alchemy_max_overflow'): ('core', 'sql_alchemy_max_overflow', '2.3.0'),
- ('database', 'sql_alchemy_pool_recycle'): ('core', 'sql_alchemy_pool_recycle', '2.3.0'),
- ('database', 'sql_alchemy_pool_pre_ping'): ('core', 'sql_alchemy_pool_pre_ping', '2.3.0'),
- ('database', 'sql_alchemy_schema'): ('core', 'sql_alchemy_schema', '2.3.0'),
- ('database', 'sql_alchemy_connect_args'): ('core', 'sql_alchemy_connect_args', '2.3.0'),
- ('database', 'load_default_connections'): ('core', 'load_default_connections', '2.3.0'),
- ('database', 'max_db_retries'): ('core', 'max_db_retries', '2.3.0'),
+ ("logging", "task_log_reader"): ("core", "task_log_reader", "2.0.0"),
+ ("metrics", "statsd_on"): ("scheduler", "statsd_on", "2.0.0"),
+ ("metrics", "statsd_host"): ("scheduler", "statsd_host", "2.0.0"),
+ ("metrics", "statsd_port"): ("scheduler", "statsd_port", "2.0.0"),
+ ("metrics", "statsd_prefix"): ("scheduler", "statsd_prefix", "2.0.0"),
+ ("metrics", "statsd_allow_list"): ("scheduler", "statsd_allow_list", "2.0.0"),
+ ("metrics", "stat_name_handler"): ("scheduler", "stat_name_handler", "2.0.0"),
+ ("metrics", "statsd_datadog_enabled"): ("scheduler", "statsd_datadog_enabled", "2.0.0"),
+ ("metrics", "statsd_datadog_tags"): ("scheduler", "statsd_datadog_tags", "2.0.0"),
+ ("metrics", "statsd_custom_client_path"): ("scheduler", "statsd_custom_client_path", "2.0.0"),
+ ("scheduler", "parsing_processes"): ("scheduler", "max_threads", "1.10.14"),
+ ("scheduler", "scheduler_idle_sleep_time"): ("scheduler", "processor_poll_interval", "2.2.0"),
+ ("operators", "default_queue"): ("celery", "default_queue", "2.1.0"),
+ ("core", "hide_sensitive_var_conn_fields"): ("admin", "hide_sensitive_variable_fields", "2.1.0"),
+ ("core", "sensitive_var_conn_names"): ("admin", "sensitive_variable_fields", "2.1.0"),
+ ("core", "default_pool_task_slot_count"): ("core", "non_pooled_task_slot_count", "1.10.4"),
+ ("core", "max_active_tasks_per_dag"): ("core", "dag_concurrency", "2.2.0"),
+ ("logging", "worker_log_server_port"): ("celery", "worker_log_server_port", "2.2.0"),
+ ("api", "access_control_allow_origins"): ("api", "access_control_allow_origin", "2.2.0"),
+ ("api", "auth_backends"): ("api", "auth_backend", "2.3.0"),
+ ("database", "sql_alchemy_conn"): ("core", "sql_alchemy_conn", "2.3.0"),
+ ("database", "sql_engine_encoding"): ("core", "sql_engine_encoding", "2.3.0"),
+ ("database", "sql_engine_collation_for_ids"): ("core", "sql_engine_collation_for_ids", "2.3.0"),
+ ("database", "sql_alchemy_pool_enabled"): ("core", "sql_alchemy_pool_enabled", "2.3.0"),
+ ("database", "sql_alchemy_pool_size"): ("core", "sql_alchemy_pool_size", "2.3.0"),
+ ("database", "sql_alchemy_max_overflow"): ("core", "sql_alchemy_max_overflow", "2.3.0"),
+ ("database", "sql_alchemy_pool_recycle"): ("core", "sql_alchemy_pool_recycle", "2.3.0"),
+ ("database", "sql_alchemy_pool_pre_ping"): ("core", "sql_alchemy_pool_pre_ping", "2.3.0"),
+ ("database", "sql_alchemy_schema"): ("core", "sql_alchemy_schema", "2.3.0"),
+ ("database", "sql_alchemy_connect_args"): ("core", "sql_alchemy_connect_args", "2.3.0"),
+ ("database", "load_default_connections"): ("core", "load_default_connections", "2.3.0"),
+ ("database", "max_db_retries"): ("core", "max_db_retries", "2.3.0"),
+ ("scheduler", "parsing_cleanup_interval"): ("scheduler", "deactivate_stale_dags_interval", "2.5.0"),
}
+ # A mapping of new section -> (old section, since_version).
+ deprecated_sections: dict[str, tuple[str, str]] = {"kubernetes_executor": ("kubernetes", "2.5.0")}
+
+ # Now build the inverse so we can go from old_section/old_key to new_section/new_key
+ # if someone tries to retrieve it based on old_section/old_key
+ @cached_property
+ def inversed_deprecated_options(self):
+ return {(sec, name): key for key, (sec, name, ver) in self.deprecated_options.items()}
+
+ @cached_property
+ def inversed_deprecated_sections(self):
+ return {
+ old_section: new_section for new_section, (old_section, ver) in self.deprecated_sections.items()
+ }
+
# A mapping of old default values that we want to change and warn the user
# about. Mapping of section -> setting -> { old, replace, by_version }
- deprecated_values: Dict[str, Dict[str, Tuple[Pattern, str, str]]] = {
- 'core': {
- 'hostname_callable': (re.compile(r':'), r'.', '2.1'),
+ deprecated_values: dict[str, dict[str, tuple[Pattern, str, str]]] = {
+ "core": {
+ "hostname_callable": (re.compile(r":"), r".", "2.1"),
},
- 'webserver': {
- 'navbar_color': (re.compile(r'\A#007A87\Z', re.IGNORECASE), '#fff', '2.1'),
- 'dag_default_view': (re.compile(r'^tree$'), 'grid', '3.0'),
+ "webserver": {
+ "navbar_color": (re.compile(r"\A#007A87\Z", re.IGNORECASE), "#fff", "2.1"),
+ "dag_default_view": (re.compile(r"^tree$"), "grid", "3.0"),
},
- 'email': {
- 'email_backend': (
- re.compile(r'^airflow\.contrib\.utils\.sendgrid\.send_email$'),
- r'airflow.providers.sendgrid.utils.emailer.send_email',
- '2.1',
+ "email": {
+ "email_backend": (
+ re.compile(r"^airflow\.contrib\.utils\.sendgrid\.send_email$"),
+ r"airflow.providers.sendgrid.utils.emailer.send_email",
+ "2.1",
),
},
- 'logging': {
- 'log_filename_template': (
+ "logging": {
+ "log_filename_template": (
re.compile(re.escape("{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log")),
"XX-set-after-default-config-loaded-XX",
- '3.0',
+ "3.0",
),
},
- 'api': {
- 'auth_backends': (
- re.compile(r'^airflow\.api\.auth\.backend\.deny_all$|^$'),
- 'airflow.api.auth.backend.session',
- '3.0',
+ "api": {
+ "auth_backends": (
+ re.compile(r"^airflow\.api\.auth\.backend\.deny_all$|^$"),
+ "airflow.api.auth.backend.session",
+ "3.0",
),
},
- 'elasticsearch': {
- 'log_id_template': (
- re.compile('^' + re.escape('{dag_id}-{task_id}-{run_id}-{try_number}') + '$'),
- '{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}',
- '3.0',
+ "elasticsearch": {
+ "log_id_template": (
+ re.compile("^" + re.escape("{dag_id}-{task_id}-{execution_date}-{try_number}") + "$"),
+ "{dag_id}-{task_id}-{run_id}-{map_index}-{try_number}",
+ "3.0",
)
},
}
- _available_logging_levels = ['CRITICAL', 'FATAL', 'ERROR', 'WARN', 'WARNING', 'INFO', 'DEBUG']
+ _available_logging_levels = ["CRITICAL", "FATAL", "ERROR", "WARN", "WARNING", "INFO", "DEBUG"]
enums_options = {
("core", "default_task_weight_rule"): sorted(WeightRule.all_weight_rules()),
("core", "dag_ignore_file_syntax"): ["regexp", "glob"],
- ('core', 'mp_start_method'): multiprocessing.get_all_start_methods(),
+ ("core", "mp_start_method"): multiprocessing.get_all_start_methods(),
("scheduler", "file_parsing_sort_mode"): ["modified_time", "random_seeded_by_host", "alphabetical"],
("logging", "logging_level"): _available_logging_levels,
("logging", "fab_logging_level"): _available_logging_levels,
# celery_logging_level can be empty, which uses logging_level as fallback
- ("logging", "celery_logging_level"): _available_logging_levels + [''],
- ("webserver", "analytical_tool"): ['google_analytics', 'metarouter', 'segment', ''],
+ ("logging", "celery_logging_level"): _available_logging_levels + [""],
+ ("webserver", "analytical_tool"): ["google_analytics", "metarouter", "segment", ""],
}
- upgraded_values: Dict[Tuple[str, str], str]
+ upgraded_values: dict[tuple[str, str], str]
"""Mapping of (section,option) to the old value that was upgraded"""
# This method transforms option names on every read, get, or set operation.
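The hunk above adds `deprecated_sections` and two cached inverse mappings so that a lookup by old section or key can be redirected to the new location. A minimal sketch of how those inverses could be consulted, using a trimmed copy of the tables; the `DeprecationMaps` class and its `resolve` helper are hypothetical and not part of the diff:

```python
from __future__ import annotations

from functools import cached_property


class DeprecationMaps:
    # (new section, new option) -> (old section, old option, since_version)
    deprecated_options: dict[tuple[str, str], tuple[str, str, str]] = {
        ("database", "sql_alchemy_conn"): ("core", "sql_alchemy_conn", "2.3.0"),
        ("scheduler", "parsing_processes"): ("scheduler", "max_threads", "1.10.14"),
    }
    # new section -> (old section, since_version)
    deprecated_sections: dict[str, tuple[str, str]] = {"kubernetes_executor": ("kubernetes", "2.5.0")}

    @cached_property
    def inversed_deprecated_options(self):
        return {(sec, name): key for key, (sec, name, _ver) in self.deprecated_options.items()}

    @cached_property
    def inversed_deprecated_sections(self):
        return {old: new for new, (old, _ver) in self.deprecated_sections.items()}

    def resolve(self, section: str, key: str) -> tuple[str, str]:
        """Hypothetical helper: map an old (section, key) to its current home."""
        if (section, key) in self.inversed_deprecated_options:
            return self.inversed_deprecated_options[(section, key)]
        return self.inversed_deprecated_sections.get(section, section), key


maps = DeprecationMaps()
moved_option = maps.resolve("core", "sql_alchemy_conn")
moved_section = maps.resolve("kubernetes", "pod_template_file")
untouched = maps.resolve("core", "executor")
```

Caching the inverses avoids rebuilding the dictionaries on every lookup, at the cost of not noticing later mutations of the source tables.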
@@ -285,7 +309,7 @@ class AirflowConfigParser(ConfigParser):
def optionxform(self, optionstr: str) -> str:
return optionstr
- def __init__(self, default_config: Optional[str] = None, *args, **kwargs):
+ def __init__(self, default_config: str | None = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.upgraded_values = {}
@@ -293,10 +317,10 @@ def __init__(self, default_config: Optional[str] = None, *args, **kwargs):
if default_config is not None:
self.airflow_defaults.read_string(default_config)
# Set the upgrade value based on the current loaded default
- default = self.airflow_defaults.get('logging', 'log_filename_template', fallback=None)
+ default = self.airflow_defaults.get("logging", "log_filename_template", fallback=None)
if default:
- replacement = self.deprecated_values['logging']['log_filename_template']
- self.deprecated_values['logging']['log_filename_template'] = (
+ replacement = self.deprecated_values["logging"]["log_filename_template"]
+ self.deprecated_values["logging"]["log_filename_template"] = (
replacement[0],
default,
replacement[2],
@@ -304,12 +328,13 @@ def __init__(self, default_config: Optional[str] = None, *args, **kwargs):
else:
# In case of tests it might not exist
with suppress(KeyError):
- del self.deprecated_values['logging']['log_filename_template']
+ del self.deprecated_values["logging"]["log_filename_template"]
else:
with suppress(KeyError):
- del self.deprecated_values['logging']['log_filename_template']
+ del self.deprecated_values["logging"]["log_filename_template"]
self.is_validated = False
+ self._suppress_future_warnings = False
def validate(self):
self._validate_config_dependencies()
@@ -337,14 +362,15 @@ def validate(self):
def _upgrade_auth_backends(self):
"""
- Ensure a custom auth_backends setting contains session,
- which is needed by the UI for ajax queries.
+ Ensure a custom auth_backends setting contains session.
+
+ This is required by the UI for ajax queries.
"""
old_value = self.get("api", "auth_backends", fallback="")
- if old_value in ('airflow.api.auth.backend.default', ''):
+ if old_value in ("airflow.api.auth.backend.default", ""):
# handled by deprecated_values
pass
- elif old_value.find('airflow.api.auth.backend.session') == -1:
+ elif old_value.find("airflow.api.auth.backend.session") == -1:
new_value = old_value + ",airflow.api.auth.backend.session"
self._update_env_var(section="api", name="auth_backends", new_value=new_value)
self.upgraded_values[("api", "auth_backends")] = old_value
@@ -355,28 +381,33 @@ def _upgrade_auth_backends(self):
os.environ.pop(old_env_var, None)
warnings.warn(
- 'The auth_backends setting in [api] has had airflow.api.auth.backend.session added '
- 'in the running config, which is needed by the UI. Please update your config before '
- 'Apache Airflow 3.0.',
+ "The auth_backends setting in [api] has had airflow.api.auth.backend.session added "
+ "in the running config, which is needed by the UI. Please update your config before "
+ "Apache Airflow 3.0.",
FutureWarning,
)
def _upgrade_postgres_metastore_conn(self):
- """As of sqlalchemy 1.4, scheme `postgres+psycopg2` must be replaced with `postgresql`"""
- section, key = 'database', 'sql_alchemy_conn'
- old_value = self.get(section, key)
- bad_scheme = 'postgres+psycopg2'
- good_scheme = 'postgresql'
- parsed = urlparse(old_value)
- if parsed.scheme == bad_scheme:
+ """
+ Upgrade SQLAlchemy connection schemes.
+
+ As of SQLAlchemy 1.4, schemes `postgres+psycopg2` and `postgres`
+ must be replaced with `postgresql`.
+ """
+ section, key = "database", "sql_alchemy_conn"
+ old_value = self.get(section, key, _extra_stacklevel=1)
+ bad_schemes = ["postgres+psycopg2", "postgres"]
+ good_scheme = "postgresql"
+ parsed = urlsplit(old_value)
+ if parsed.scheme in bad_schemes:
warnings.warn(
- f"Bad scheme in Airflow configuration core > sql_alchemy_conn: `{bad_scheme}`. "
- "As of SqlAlchemy 1.4 (adopted in Airflow 2.3) this is no longer supported. You must "
+ f"Bad scheme in Airflow configuration core > sql_alchemy_conn: `{parsed.scheme}`. "
+ "As of SQLAlchemy 1.4 (adopted in Airflow 2.3) this is no longer supported. You must "
f"change to `{good_scheme}` before the next Airflow release.",
FutureWarning,
)
self.upgraded_values[(section, key)] = old_value
- new_value = re.sub('^' + re.escape(f"{bad_scheme}://"), f"{good_scheme}://", old_value)
+ new_value = re.sub("^" + re.escape(f"{parsed.scheme}://"), f"{good_scheme}://", old_value)
self._update_env_var(section=section, name=key, new_value=new_value)
# if the old value is set via env var, we need to wipe it
@@ -385,7 +416,7 @@ def _upgrade_postgres_metastore_conn(self):
os.environ.pop(old_env_var, None)
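The scheme rewrite in `_upgrade_postgres_metastore_conn` can be tried in isolation. The following is a minimal sketch, not the Airflow implementation: `upgrade_postgres_scheme`, `BAD_SCHEMES` and `GOOD_SCHEME` are hypothetical names introduced here, and the warning and environment-variable handling from the diff are left out.

```python
import re
from urllib.parse import urlsplit

# Hypothetical module-level constants mirroring the values used in the diff.
BAD_SCHEMES = ("postgres+psycopg2", "postgres")
GOOD_SCHEME = "postgresql"


def upgrade_postgres_scheme(conn: str) -> str:
    """Return conn with a deprecated Postgres scheme replaced, else unchanged."""
    scheme = urlsplit(conn).scheme
    if scheme not in BAD_SCHEMES:
        return conn
    # Anchor the pattern at the start so only the scheme prefix is rewritten.
    return re.sub("^" + re.escape(f"{scheme}://"), f"{GOOD_SCHEME}://", conn)


print(upgrade_postgres_scheme("postgres://user@host/airflow"))
```

Using the parsed scheme rather than a fixed string is what lets a single substitution cover both deprecated spellings.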
def _validate_enums(self):
- """Validate that enum type config has an accepted value"""
+ """Validate that enum type config has an accepted value."""
for (section_key, option_key), enum_options in self.enums_options.items():
if self.has_option(section_key, option_key):
value = self.get(section_key, option_key)
@@ -397,14 +428,16 @@ def _validate_enums(self):
def _validate_config_dependencies(self):
"""
- Validate that config values aren't invalid given other config values
+ Validate config values against other config values.
+
+ Values are considered invalid when they conflict with other config values
or system-level limitations and requirements.
"""
is_executor_without_sqlite_support = self.get("core", "executor") not in (
- 'DebugExecutor',
- 'SequentialExecutor',
+ "DebugExecutor",
+ "SequentialExecutor",
)
- is_sqlite = "sqlite" in self.get('database', 'sql_alchemy_conn')
+ is_sqlite = "sqlite" in self.get("database", "sql_alchemy_conn")
if is_sqlite and is_executor_without_sqlite_support:
raise AirflowConfigException(f"error: cannot use sqlite with the {self.get('core', 'executor')}")
if is_sqlite:
@@ -424,7 +457,7 @@ def _validate_config_dependencies(self):
def _using_old_value(self, old: Pattern, current_value: str) -> bool:
return old.search(current_value) is not None
- def _update_env_var(self, section: str, name: str, new_value: Union[str]):
+ def _update_env_var(self, section: str, name: str, new_value: str):
env_var = self._env_var_name(section, name)
# Set it as an env var so that any subprocesses keep the same override!
os.environ[env_var] = new_value
@@ -432,14 +465,14 @@ def _update_env_var(self, section: str, name: str, new_value: Union[str]):
@staticmethod
def _create_future_warning(name: str, section: str, current_value: Any, new_value: Any, version: str):
warnings.warn(
- f'The {name!r} setting in [{section}] has the old default value of {current_value!r}. '
- f'This value has been changed to {new_value!r} in the running config, but '
- f'please update your config before Apache Airflow {version}.',
+ f"The {name!r} setting in [{section}] has the old default value of {current_value!r}. "
+ f"This value has been changed to {new_value!r} in the running config, but "
+ f"please update your config before Apache Airflow {version}.",
FutureWarning,
)
def _env_var_name(self, section: str, key: str) -> str:
- return f'{ENV_VAR_PREFIX}{section.upper()}__{key.upper()}'
+ return f"{ENV_VAR_PREFIX}{section.upper()}__{key.upper()}"
def _get_env_var_option(self, section: str, key: str):
# must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
@@ -447,13 +480,13 @@ def _get_env_var_option(self, section: str, key: str):
if env_var in os.environ:
return expand_env_var(os.environ[env_var])
# alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
- env_var_cmd = env_var + '_CMD'
+ env_var_cmd = env_var + "_CMD"
if env_var_cmd in os.environ:
# if this is a valid command key...
if (section, key) in self.sensitive_config_values:
return run_command(os.environ[env_var_cmd])
# alternatively AIRFLOW__{SECTION}__{KEY}_SECRET (to get from Secrets Backend)
- env_var_secret_path = env_var + '_SECRET'
+ env_var_secret_path = env_var + "_SECRET"
if env_var_secret_path in os.environ:
# if this is a valid secret path...
if (section, key) in self.sensitive_config_values:
@@ -461,7 +494,7 @@ def _get_env_var_option(self, section: str, key: str):
return None
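The naming convention that `_get_env_var_option` relies on (`AIRFLOW__{SECTION}__{KEY}`, with `_CMD` and `_SECRET` variants honoured only for sensitive options) can be sketched as below. This is an illustration under assumptions: `find_env_source` is a hypothetical helper that only reports which variable would be used, and it does not run commands or query a secrets backend as the real method does.

```python
ENV_VAR_PREFIX = "AIRFLOW__"  # assumed value, matching the AIRFLOW__{SECTION}__{KEY} format in the source


def env_var_name(section, key):
    """Build the environment variable name for a section/key pair."""
    return f"{ENV_VAR_PREFIX}{section.upper()}__{key.upper()}"


def find_env_source(section, key, environ, sensitive):
    """Return (variable_name, kind) for the first matching variable, or None."""
    name = env_var_name(section, key)
    if name in environ:
        return name, "value"
    # The _CMD and _SECRET variants are only honoured for sensitive options.
    for suffix, kind in (("_CMD", "command"), ("_SECRET", "secret")):
        if name + suffix in environ and (section, key) in sensitive:
            return name + suffix, kind
    return None


sensitive = {("database", "sql_alchemy_conn")}
environ = {"AIRFLOW__DATABASE__SQL_ALCHEMY_CONN_CMD": "cat /run/secrets/conn"}
print(find_env_source("database", "sql_alchemy_conn", environ, sensitive))
```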
def _get_cmd_option(self, section: str, key: str):
- fallback_key = key + '_cmd'
+ fallback_key = key + "_cmd"
if (section, key) in self.sensitive_config_values:
if super().has_option(section, fallback_key):
command = super().get(section, fallback_key)
@@ -470,8 +503,8 @@ def _get_cmd_option(self, section: str, key: str):
def _get_cmd_option_from_config_sources(
self, config_sources: ConfigSourcesType, section: str, key: str
- ) -> Optional[str]:
- fallback_key = key + '_cmd'
+ ) -> str | None:
+ fallback_key = key + "_cmd"
if (section, key) in self.sensitive_config_values:
section_dict = config_sources.get(section)
if section_dict is not None:
@@ -484,9 +517,9 @@ def _get_cmd_option_from_config_sources(
return run_command(command)
return None
- def _get_secret_option(self, section: str, key: str) -> Optional[str]:
- """Get Config option values from Secret Backend"""
- fallback_key = key + '_secret'
+ def _get_secret_option(self, section: str, key: str) -> str | None:
+ """Get Config option values from Secret Backend."""
+ fallback_key = key + "_secret"
if (section, key) in self.sensitive_config_values:
if super().has_option(section, fallback_key):
secrets_path = super().get(section, fallback_key)
@@ -495,8 +528,8 @@ def _get_secret_option(self, section: str, key: str) -> Optional[str]:
def _get_secret_option_from_config_sources(
self, config_sources: ConfigSourcesType, section: str, key: str
- ) -> Optional[str]:
- fallback_key = key + '_secret'
+ ) -> str | None:
+ fallback_key = key + "_secret"
if (section, key) in self.sensitive_config_values:
section_dict = config_sources.get(section)
if section_dict is not None:
@@ -510,115 +543,219 @@ def _get_secret_option_from_config_sources(
return None
def get_mandatory_value(self, section: str, key: str, **kwargs) -> str:
- value = self.get(section, key, **kwargs)
+ value = self.get(section, key, _extra_stacklevel=1, **kwargs)
if value is None:
raise ValueError(f"The value {section}/{key} should be set!")
return value
- def get(self, section: str, key: str, **kwargs) -> Optional[str]: # type: ignore[override]
+ @overload # type: ignore[override]
+ def get(self, section: str, key: str, fallback: str = ..., **kwargs) -> str: # type: ignore[override]
+
+ ...
+
+ @overload # type: ignore[override]
+ def get(self, section: str, key: str, **kwargs) -> str | None: # type: ignore[override]
+
+ ...
+
+ def get( # type: ignore[override, misc]
+ self,
+ section: str,
+ key: str,
+ _extra_stacklevel: int = 0,
+ **kwargs,
+ ) -> str | None:
section = str(section).lower()
key = str(key).lower()
+ warning_emitted = False
+ deprecated_section: str | None
+ deprecated_key: str | None
+
+ # For when we rename whole sections
+ if section in self.inversed_deprecated_sections:
+ deprecated_section, deprecated_key = (section, key)
+ section = self.inversed_deprecated_sections[section]
+ if not self._suppress_future_warnings:
+ warnings.warn(
+ f"The config section [{deprecated_section}] has been renamed to "
+ f"[{section}]. Please update your `conf.get*` call to use the new name",
+ FutureWarning,
+ stacklevel=2 + _extra_stacklevel,
+ )
+ # Don't warn about individual rename if the whole section is renamed
+ warning_emitted = True
+ elif (section, key) in self.inversed_deprecated_options:
+ # Handle using deprecated section/key instead of the new section/key
+ new_section, new_key = self.inversed_deprecated_options[(section, key)]
+ if not self._suppress_future_warnings and not warning_emitted:
+ warnings.warn(
+ f"section/key [{section}/{key}] has been deprecated, you should use "
+ f"[{new_section}/{new_key}] instead. Please update your `conf.get*` call to use the "
+ "new name",
+ FutureWarning,
+ stacklevel=2 + _extra_stacklevel,
+ )
+ warning_emitted = True
+ deprecated_section, deprecated_key = section, key
+ section, key = (new_section, new_key)
+ elif section in self.deprecated_sections:
+ # When accessing the new section name, make sure we check under the old config name
+ deprecated_key = key
+ deprecated_section = self.deprecated_sections[section][0]
+ else:
+ deprecated_section, deprecated_key, _ = self.deprecated_options.get(
+ (section, key), (None, None, None)
+ )
- deprecated_section, deprecated_key, _ = self.deprecated_options.get(
- (section, key), (None, None, None)
+ # first check environment variables
+ option = self._get_environment_variables(
+ deprecated_key,
+ deprecated_section,
+ key,
+ section,
+ issue_warning=not warning_emitted,
+ extra_stacklevel=_extra_stacklevel,
)
-
- option = self._get_environment_variables(deprecated_key, deprecated_section, key, section)
if option is not None:
return option
- option = self._get_option_from_config_file(deprecated_key, deprecated_section, key, kwargs, section)
+ # ...then the config file
+ option = self._get_option_from_config_file(
+ deprecated_key,
+ deprecated_section,
+ key,
+ kwargs,
+ section,
+ issue_warning=not warning_emitted,
+ extra_stacklevel=_extra_stacklevel,
+ )
if option is not None:
return option
- option = self._get_option_from_commands(deprecated_key, deprecated_section, key, section)
+ # ...then commands
+ option = self._get_option_from_commands(
+ deprecated_key,
+ deprecated_section,
+ key,
+ section,
+ issue_warning=not warning_emitted,
+ extra_stacklevel=_extra_stacklevel,
+ )
if option is not None:
return option
- option = self._get_option_from_secrets(deprecated_key, deprecated_section, key, section)
+ # ...then from secret backends
+ option = self._get_option_from_secrets(
+ deprecated_key,
+ deprecated_section,
+ key,
+ section,
+ issue_warning=not warning_emitted,
+ extra_stacklevel=_extra_stacklevel,
+ )
if option is not None:
return option
- return self._get_option_from_default_config(section, key, **kwargs)
-
- def _get_option_from_default_config(self, section: str, key: str, **kwargs) -> Optional[str]:
# ...then the default config
- if self.airflow_defaults.has_option(section, key) or 'fallback' in kwargs:
+ if self.airflow_defaults.has_option(section, key) or "fallback" in kwargs:
return expand_env_var(self.airflow_defaults.get(section, key, **kwargs))
- else:
- log.warning("section/key [%s/%s] not found in config", section, key)
+ log.warning("section/key [%s/%s] not found in config", section, key)
- raise AirflowConfigException(f"section/key [{section}/{key}] not found in config")
+ raise AirflowConfigException(f"section/key [{section}/{key}] not found in config")
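The lookup order that `get` walks through (environment variables, then the config file, then commands, then secret backends, then the defaults) amounts to "first source that returns something other than `None` wins". A minimal sketch of that pattern, assuming each source is wrapped in a zero-argument callable; `first_available` is a hypothetical name and the deprecation handling is omitted:

```python
from typing import Callable, List, Optional, Tuple


def first_available(sources: List[Tuple[str, Callable[[], Optional[str]]]]) -> Tuple[str, str]:
    """Return (source_name, value) from the first source that yields a value."""
    for name, getter in sources:
        value = getter()
        if value is not None:
            return name, value
    raise KeyError("option not found in any source")


# Sources are listed in priority order, highest first.
sources = [
    ("env", lambda: None),
    ("config file", lambda: "from-file"),
    ("default", lambda: "fallback"),
]
print(first_available(sources))
```

Note that an empty string still counts as a value here, because the check is `is not None` rather than truthiness.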
def _get_option_from_secrets(
- self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str
- ) -> Optional[str]:
- # ...then from secret backends
+ self,
+ deprecated_key: str | None,
+ deprecated_section: str | None,
+ key: str,
+ section: str,
+ issue_warning: bool = True,
+ extra_stacklevel: int = 0,
+ ) -> str | None:
option = self._get_secret_option(section, key)
if option:
return option
if deprecated_section and deprecated_key:
- option = self._get_secret_option(deprecated_section, deprecated_key)
+ with self.suppress_future_warnings():
+ option = self._get_secret_option(deprecated_section, deprecated_key)
if option:
- self._warn_deprecate(section, key, deprecated_section, deprecated_key)
+ if issue_warning:
+ self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
return option
return None
def _get_option_from_commands(
- self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str
- ) -> Optional[str]:
- # ...then commands
+ self,
+ deprecated_key: str | None,
+ deprecated_section: str | None,
+ key: str,
+ section: str,
+ issue_warning: bool = True,
+ extra_stacklevel: int = 0,
+ ) -> str | None:
option = self._get_cmd_option(section, key)
if option:
return option
if deprecated_section and deprecated_key:
- option = self._get_cmd_option(deprecated_section, deprecated_key)
+ with self.suppress_future_warnings():
+ option = self._get_cmd_option(deprecated_section, deprecated_key)
if option:
- self._warn_deprecate(section, key, deprecated_section, deprecated_key)
+ if issue_warning:
+ self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
return option
return None
def _get_option_from_config_file(
self,
- deprecated_key: Optional[str],
- deprecated_section: Optional[str],
+ deprecated_key: str | None,
+ deprecated_section: str | None,
key: str,
- kwargs: Dict[str, Any],
+ kwargs: dict[str, Any],
section: str,
- ) -> Optional[str]:
- # ...then the config file
+ issue_warning: bool = True,
+ extra_stacklevel: int = 0,
+ ) -> str | None:
if super().has_option(section, key):
# Use the parent's methods to get the actual config here to be able to
# separate the config from default config.
return expand_env_var(super().get(section, key, **kwargs))
if deprecated_section and deprecated_key:
if super().has_option(deprecated_section, deprecated_key):
- self._warn_deprecate(section, key, deprecated_section, deprecated_key)
- return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs))
+ if issue_warning:
+ self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
+ with self.suppress_future_warnings():
+ return expand_env_var(super().get(deprecated_section, deprecated_key, **kwargs))
return None
def _get_environment_variables(
- self, deprecated_key: Optional[str], deprecated_section: Optional[str], key: str, section: str
- ) -> Optional[str]:
- # first check environment variables
+ self,
+ deprecated_key: str | None,
+ deprecated_section: str | None,
+ key: str,
+ section: str,
+ issue_warning: bool = True,
+ extra_stacklevel: int = 0,
+ ) -> str | None:
option = self._get_env_var_option(section, key)
if option is not None:
return option
if deprecated_section and deprecated_key:
- option = self._get_env_var_option(deprecated_section, deprecated_key)
+ with self.suppress_future_warnings():
+ option = self._get_env_var_option(deprecated_section, deprecated_key)
if option is not None:
- self._warn_deprecate(section, key, deprecated_section, deprecated_key)
+ if issue_warning:
+ self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
return option
return None
def getboolean(self, section: str, key: str, **kwargs) -> bool: # type: ignore[override]
- val = str(self.get(section, key, **kwargs)).lower().strip()
- if '#' in val:
- val = val.split('#')[0].strip()
- if val in ('t', 'true', '1'):
+ val = str(self.get(section, key, _extra_stacklevel=1, **kwargs)).lower().strip()
+ if "#" in val:
+ val = val.split("#")[0].strip()
+ if val in ("t", "true", "1"):
return True
- elif val in ('f', 'false', '0'):
+ elif val in ("f", "false", "0"):
return False
else:
raise AirflowConfigException(
@@ -627,10 +764,10 @@ def getboolean(self, section: str, key: str, **kwargs) -> bool: # type: ignore[
)
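The parsing rules in `getboolean` (lower-case, strip, drop anything after `#`, then match against a fixed set of literals) can be exercised on their own. A sketch assuming a plain `ValueError` in place of `AirflowConfigException`; `parse_bool` is a hypothetical name:

```python
def parse_bool(raw) -> bool:
    """Parse a config string as a boolean, ignoring a trailing inline comment."""
    val = str(raw).lower().strip()
    if "#" in val:
        # Keep only the text before the first '#'.
        val = val.split("#")[0].strip()
    if val in ("t", "true", "1"):
        return True
    if val in ("f", "false", "0"):
        return False
    raise ValueError(f"Cannot interpret {raw!r} as a boolean")


print(parse_bool("True  # enabled for local testing"))
```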
def getint(self, section: str, key: str, **kwargs) -> int: # type: ignore[override]
- val = self.get(section, key, **kwargs)
+ val = self.get(section, key, _extra_stacklevel=1, **kwargs)
if val is None:
raise AirflowConfigException(
- f'Failed to convert value None to int. '
+ f"Failed to convert value None to int. "
f'Please check "{key}" key in "{section}" section is set.'
)
try:
@@ -642,10 +779,10 @@ def getint(self, section: str, key: str, **kwargs) -> int: # type: ignore[overr
)
def getfloat(self, section: str, key: str, **kwargs) -> float: # type: ignore[override]
- val = self.get(section, key, **kwargs)
+ val = self.get(section, key, _extra_stacklevel=1, **kwargs)
if val is None:
raise AirflowConfigException(
- f'Failed to convert value None to float. '
+ f"Failed to convert value None to float. "
f'Please check "{key}" key in "{section}" section is set.'
)
try:
@@ -679,7 +816,7 @@ def getimport(self, section: str, key: str, **kwargs) -> Any:
def getjson(
self, section: str, key: str, fallback=_UNSET, **kwargs
- ) -> Union[dict, list, str, int, float, None]:
+ ) -> dict | list | str | int | float | None:
"""
Return a config value parsed from a JSON string.
@@ -693,7 +830,7 @@ def getjson(
fallback = _UNSET
try:
- data = self.get(section=section, key=key, fallback=fallback, **kwargs)
+ data = self.get(section=section, key=key, fallback=fallback, _extra_stacklevel=1, **kwargs)
except (NoSectionError, NoOptionError):
return default
@@ -703,13 +840,14 @@ def getjson(
try:
return json.loads(data)
except JSONDecodeError as e:
- raise AirflowConfigException(f'Unable to parse [{section}] {key!r} as valid json') from e
+ raise AirflowConfigException(f"Unable to parse [{section}] {key!r} as valid json") from e
def gettimedelta(
self, section: str, key: str, fallback: Any = None, **kwargs
- ) -> Optional[datetime.timedelta]:
+ ) -> datetime.timedelta | None:
"""
Gets the config value for the given section and key, and converts it into datetime.timedelta object.
+
If the key is missing, then it is considered as `None`.
:param section: the section from the config
@@ -718,7 +856,7 @@ def gettimedelta(
:raises AirflowConfigException: if the value cannot be converted (ValueError or OverflowError)
:return: datetime.timedelta(seconds=) or None
"""
- val = self.get(section, key, fallback=fallback, **kwargs)
+ val = self.get(section, key, fallback=fallback, _extra_stacklevel=1, **kwargs)
if val:
# the given value must be convertible to integer
@@ -734,8 +872,8 @@ def gettimedelta(
return datetime.timedelta(seconds=int_val)
except OverflowError as err:
raise AirflowConfigException(
- f'Failed to convert value to timedelta in `seconds`. '
- f'{err}. '
+ f"Failed to convert value to timedelta in `seconds`. "
+ f"{err}. "
f'Please check "{key}" key in "{section}" section. Current value: "{val}".'
)
@@ -743,12 +881,7 @@ def gettimedelta(
def read(
self,
- filenames: Union[
- str,
- bytes,
- os.PathLike,
- Iterable[Union[str, bytes, os.PathLike]],
- ],
+ filenames: (str | bytes | os.PathLike | Iterable[str | bytes | os.PathLike]),
encoding=None,
):
super().read(filenames=filenames, encoding=encoding)
@@ -756,7 +889,7 @@ def read(
# RawConfigParser annotates this with "Mapping" from collections.abc, which is not subscriptable,
# so we have to use dict here.
def read_dict( # type: ignore[override]
- self, dictionary: Dict[str, Dict[str, Any]], source: str = ''
+ self, dictionary: dict[str, dict[str, Any]], source: str = ""
):
super().read_dict(dictionary=dictionary, source=source)
@@ -765,16 +898,17 @@ def has_option(self, section: str, option: str) -> bool:
# Using self.get() to avoid reimplementing the priority order
# of config variables (env, config, cmd, defaults)
# UNSET to avoid logging a warning about missing values
- self.get(section, option, fallback=_UNSET)
+ self.get(section, option, fallback=_UNSET, _extra_stacklevel=1)
return True
except (NoOptionError, NoSectionError):
return False
def remove_option(self, section: str, option: str, remove_default: bool = True):
"""
- Remove an option if it exists in config from a file or
- default config. If both of config have the same option, this removes
- the option in both configs unless remove_default=False.
+ Remove an option if it exists in config from a file or default config.
+
+ If both configs have the same option, this removes the option
+ from both of them unless remove_default=False.
"""
if super().has_option(section, option):
super().remove_option(section, option)
@@ -782,13 +916,13 @@ def remove_option(self, section: str, option: str, remove_default: bool = True):
if self.airflow_defaults.has_option(section, option) and remove_default:
self.airflow_defaults.remove_option(section, option)
- def getsection(self, section: str) -> Optional[ConfigOptionsDictType]:
+ def getsection(self, section: str) -> ConfigOptionsDictType | None:
"""
- Returns the section as a dict. Values are converted to int, float, bool
- as required.
+ Returns the section as a dict.
+
+ Values are converted to int, float, bool as required.
:param section: section from the config
- :rtype: dict
"""
if not self.has_section(section) and not self.airflow_defaults.has_section(section):
return None
@@ -800,10 +934,10 @@ def getsection(self, section: str) -> Optional[ConfigOptionsDictType]:
if self.has_section(section):
_section.update(OrderedDict(self.items(section)))
- section_prefix = self._env_var_name(section, '')
+ section_prefix = self._env_var_name(section, "")
for env_var in sorted(os.environ.keys()):
if env_var.startswith(section_prefix):
- key = env_var.replace(section_prefix, '')
+ key = env_var.replace(section_prefix, "")
if key.endswith("_CMD"):
key = key[:-4]
key = key.lower()
@@ -812,7 +946,7 @@ def getsection(self, section: str) -> Optional[ConfigOptionsDictType]:
for key, val in _section.items():
if val is None:
raise AirflowConfigException(
- f'Failed to convert value automatically. '
+ f"Failed to convert value automatically. "
f'Please check "{key}" key in "{section}" section is set.'
)
try:
@@ -821,9 +955,9 @@ def getsection(self, section: str) -> Optional[ConfigOptionsDictType]:
try:
_section[key] = float(val)
except ValueError:
- if isinstance(val, str) and val.lower() in ('t', 'true'):
+ if isinstance(val, str) and val.lower() in ("t", "true"):
_section[key] = True
- elif isinstance(val, str) and val.lower() in ('f', 'false'):
+ elif isinstance(val, str) and val.lower() in ("f", "false"):
_section[key] = False
return _section
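The value conversion at the end of `getsection` tries `int`, then `float`, then the boolean literals, and otherwise leaves the string alone. A sketch of that cascade for a single value, with `coerce` as a hypothetical helper name:

```python
def coerce(val: str):
    """Convert a config string to int, float or bool where possible."""
    try:
        return int(val)
    except ValueError:
        pass
    try:
        return float(val)
    except ValueError:
        pass
    if val.lower() in ("t", "true"):
        return True
    if val.lower() in ("f", "false"):
        return False
    # Anything else stays a string.
    return val


print([coerce(v) for v in ("8080", "0.5", "True", "LocalExecutor")])
```

Because the integer conversion runs first, `"1"` and `"0"` come back as integers here, whereas `getboolean` treats them as booleans.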
@@ -880,14 +1014,22 @@ def as_dict(
:param include_secret: Should the result of calling any *_secret config be
set (True, default), or should the _secret options be left as the
path to get the secret from (False)
- :rtype: Dict[str, Dict[str, str]]
:return: Dictionary, where the key is the name of the section and the content is
the dictionary with the name of the parameter and its value.
"""
+ if not display_sensitive:
+ # We want to hide the sensitive values at the appropriate methods
+ # since envs from cmds, secrets can be read at _include_envs method
+ if not all([include_env, include_cmds, include_secret]):
+ raise ValueError(
+ "If display_sensitive is false, then include_env, "
+ "include_cmds, include_secret must all be set as True"
+ )
+
config_sources: ConfigSourcesType = {}
configs = [
- ('default', self.airflow_defaults),
- ('airflow.cfg', self),
+ ("default", self.airflow_defaults),
+ ("airflow.cfg", self),
]
self._replace_config_with_display_sources(
@@ -919,6 +1061,20 @@ def as_dict(
else:
self._filter_by_source(config_sources, display_source, self._get_secret_option)
+ if not display_sensitive:
+ # This ensures that values from the config file are hidden too
+ # if they are not provided through env, cmd or secret
+ hidden = "< hidden >"
+ for (section, key) in self.sensitive_config_values:
+ if not config_sources.get(section):
+ continue
+ if config_sources[section].get(key, None):
+ if display_source:
+ source = config_sources[section][key][1]
+ config_sources[section][key] = (hidden, source)
+ else:
+ config_sources[section][key] = hidden
+
return config_sources
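The masking pass added to `as_dict` can be sketched on a plain nested dict. This is an assumption-laden illustration: `mask_sensitive` is a hypothetical function, and `sensitive` stands in for `self.sensitive_config_values`.

```python
def mask_sensitive(config_sources, sensitive, display_source=False, hidden="< hidden >"):
    """Replace sensitive values in place, keeping the source label if present."""
    for section, key in sensitive:
        options = config_sources.get(section)
        if not options:
            continue
        if options.get(key):
            if display_source:
                # Values are (value, source) tuples when sources are displayed.
                _, source = options[key]
                options[key] = (hidden, source)
            else:
                options[key] = hidden
    return config_sources


cfg = {"database": {"sql_alchemy_conn": "postgresql://u:p@h/db", "load_default_connections": "True"}}
print(mask_sensitive(cfg, {("database", "sql_alchemy_conn"), ("core", "fernet_key")}))
```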
def _include_secrets(
@@ -929,18 +1085,18 @@ def _include_secrets(
raw: bool,
):
for (section, key) in self.sensitive_config_values:
- value: Optional[str] = self._get_secret_option_from_config_sources(config_sources, section, key)
+ value: str | None = self._get_secret_option_from_config_sources(config_sources, section, key)
if value:
if not display_sensitive:
- value = '< hidden >'
+ value = "< hidden >"
if display_source:
- opt: Union[str, Tuple[str, str]] = (value, 'secret')
+ opt: str | tuple[str, str] = (value, "secret")
elif raw:
- opt = value.replace('%', '%%')
+ opt = value.replace("%", "%%")
else:
opt = value
config_sources.setdefault(section, OrderedDict()).update({key: opt})
- del config_sources[section][key + '_secret']
+ del config_sources[section][key + "_secret"]
def _include_commands(
self,
@@ -953,17 +1109,17 @@ def _include_commands(
opt = self._get_cmd_option_from_config_sources(config_sources, section, key)
if not opt:
continue
- opt_to_set: Union[str, Tuple[str, str], None] = opt
+ opt_to_set: str | tuple[str, str] | None = opt
if not display_sensitive:
- opt_to_set = '< hidden >'
+ opt_to_set = "< hidden >"
if display_source:
- opt_to_set = (str(opt_to_set), 'cmd')
+ opt_to_set = (str(opt_to_set), "cmd")
elif raw:
- opt_to_set = str(opt_to_set).replace('%', '%%')
+ opt_to_set = str(opt_to_set).replace("%", "%%")
if opt_to_set is not None:
- dict_to_update: Dict[str, Union[str, Tuple[str, str]]] = {key: opt_to_set}
+ dict_to_update: dict[str, str | tuple[str, str]] = {key: opt_to_set}
config_sources.setdefault(section, OrderedDict()).update(dict_to_update)
- del config_sources[section][key + '_cmd']
+ del config_sources[section][key + "_cmd"]
def _include_envs(
self,
@@ -976,26 +1132,29 @@ def _include_envs(
os_environment for os_environment in os.environ if os_environment.startswith(ENV_VAR_PREFIX)
]:
try:
- _, section, key = env_var.split('__', 2)
+ _, section, key = env_var.split("__", 2)
opt = self._get_env_var_option(section, key)
except ValueError:
continue
if opt is None:
log.warning("Ignoring unknown env var '%s'", env_var)
continue
- if not display_sensitive and env_var != self._env_var_name('core', 'unit_test_mode'):
- opt = '< hidden >'
+ if not display_sensitive and env_var != self._env_var_name("core", "unit_test_mode"):
+ # Don't hide cmd/secret values here
+ if not env_var.lower().endswith("cmd") and not env_var.lower().endswith("secret"):
+ if (section, key) in self.sensitive_config_values:
+ opt = "< hidden >"
elif raw:
- opt = opt.replace('%', '%%')
+ opt = opt.replace("%", "%%")
if display_source:
- opt = (opt, 'env var')
+ opt = (opt, "env var")
section = section.lower()
# If we lowercased the key for the kubernetes_environment_variables section,
# we could not set any Airflow environment variables, because Airflow only
# parses environment variables that start with AIRFLOW_. Therefore, this
# section is a special case.
- if section != 'kubernetes_environment_variables':
+ if section != "kubernetes_environment_variables":
key = key.lower()
config_sources.setdefault(section, OrderedDict()).update({key: opt})
@@ -1006,8 +1165,9 @@ def _filter_by_source(
getter_func,
):
"""
- Deletes default configs from current configuration (an OrderedDict of
- OrderedDicts) if it would conflict with special sensitive_config_values.
+ Deletes default configs from current configuration.
+
+ The configuration is an OrderedDict of OrderedDicts; a default is deleted
+ if it would conflict with special sensitive_config_values.
This is necessary because bare configs take precedence over the command
or secret key equivalents so if the current running config is
@@ -1020,7 +1180,6 @@ def _filter_by_source(
Source is either 'airflow.cfg', 'default', 'env var', or 'cmd'.
:param getter_func: A callback function that gets the user configured
override value for a particular sensitive_config_values config.
- :rtype: None
:return: None, the given config_sources is filtered if necessary,
otherwise untouched.
"""
@@ -1050,10 +1209,10 @@ def _filter_by_source(
@staticmethod
def _replace_config_with_display_sources(
config_sources: ConfigSourcesType,
- configs: Iterable[Tuple[str, ConfigParser]],
+ configs: Iterable[tuple[str, ConfigParser]],
display_source: bool,
raw: bool,
- deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]],
+ deprecated_options: dict[tuple[str, str], tuple[str, str, str]],
include_env: bool,
include_cmds: bool,
include_secret: bool,
@@ -1078,10 +1237,10 @@ def _replace_config_with_display_sources(
def _deprecated_value_is_set_in_config(
deprecated_section: str,
deprecated_key: str,
- configs: Iterable[Tuple[str, ConfigParser]],
+ configs: Iterable[tuple[str, ConfigParser]],
) -> bool:
for config_type, config in configs:
- if config_type == 'default':
+ if config_type == "default":
continue
try:
deprecated_section_array = config.items(section=deprecated_section, raw=True)
@@ -1095,13 +1254,13 @@ def _deprecated_value_is_set_in_config(
@staticmethod
def _deprecated_variable_is_set(deprecated_section: str, deprecated_key: str) -> bool:
return (
- os.environ.get(f'{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}')
+ os.environ.get(f"{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}")
is not None
)
@staticmethod
def _deprecated_command_is_set_in_config(
- deprecated_section: str, deprecated_key: str, configs: Iterable[Tuple[str, ConfigParser]]
+ deprecated_section: str, deprecated_key: str, configs: Iterable[tuple[str, ConfigParser]]
) -> bool:
return AirflowConfigParser._deprecated_value_is_set_in_config(
deprecated_section=deprecated_section, deprecated_key=deprecated_key + "_cmd", configs=configs
@@ -1110,13 +1269,13 @@ def _deprecated_command_is_set_in_config(
@staticmethod
def _deprecated_variable_command_is_set(deprecated_section: str, deprecated_key: str) -> bool:
return (
- os.environ.get(f'{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_CMD')
+ os.environ.get(f"{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_CMD")
is not None
)
@staticmethod
def _deprecated_secret_is_set_in_config(
- deprecated_section: str, deprecated_key: str, configs: Iterable[Tuple[str, ConfigParser]]
+ deprecated_section: str, deprecated_key: str, configs: Iterable[tuple[str, ConfigParser]]
) -> bool:
return AirflowConfigParser._deprecated_value_is_set_in_config(
deprecated_section=deprecated_section, deprecated_key=deprecated_key + "_secret", configs=configs
@@ -1125,10 +1284,17 @@ def _deprecated_secret_is_set_in_config(
@staticmethod
def _deprecated_variable_secret_is_set(deprecated_section: str, deprecated_key: str) -> bool:
return (
- os.environ.get(f'{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_SECRET')
+ os.environ.get(f"{ENV_VAR_PREFIX}{deprecated_section.upper()}__{deprecated_key.upper()}_SECRET")
is not None
)
+ @contextmanager
+ def suppress_future_warnings(self):
+ suppress_future_warnings = self._suppress_future_warnings
+ self._suppress_future_warnings = True
+ try:
+ yield self
+ finally:
+ # Restore the previous flag even if the wrapped block raises.
+ self._suppress_future_warnings = suppress_future_warnings
+
@staticmethod
def _replace_section_config_with_display_sources(
config: ConfigParser,
@@ -1137,17 +1303,22 @@ def _replace_section_config_with_display_sources(
raw: bool,
section: str,
source_name: str,
- deprecated_options: Dict[Tuple[str, str], Tuple[str, str, str]],
- configs: Iterable[Tuple[str, ConfigParser]],
+ deprecated_options: dict[tuple[str, str], tuple[str, str, str]],
+ configs: Iterable[tuple[str, ConfigParser]],
include_env: bool,
include_cmds: bool,
include_secret: bool,
):
sect = config_sources.setdefault(section, OrderedDict())
- for (k, val) in config.items(section=section, raw=raw):
+ if isinstance(config, AirflowConfigParser):
+ with config.suppress_future_warnings():
+ items = config.items(section=section, raw=raw)
+ else:
+ items = config.items(section=section, raw=raw)
+ for k, val in items:
deprecated_section, deprecated_key, _ = deprecated_options.get((section, k), (None, None, None))
if deprecated_section and deprecated_key:
- if source_name == 'default':
+ if source_name == "default":
# If deprecated entry has some non-default value set for any of the sources requested,
# We should NOT set default for the new entry (because it will override anything
# coming from the deprecated ones)
@@ -1194,62 +1365,64 @@ def load_test_config(self):
# then read test config
- path = _default_config_file_path('default_test.cfg')
+ path = _default_config_file_path("default_test.cfg")
log.info("Reading default test configuration from %s", path)
- self.read_string(_parameterized_config_from_template('default_test.cfg'))
+ self.read_string(_parameterized_config_from_template("default_test.cfg"))
# then read any "custom" test settings
log.info("Reading test configuration from %s", TEST_CONFIG_FILE)
self.read(TEST_CONFIG_FILE)
@staticmethod
- def _warn_deprecate(section: str, key: str, deprecated_section: str, deprecated_name: str):
+ def _warn_deprecate(
+ section: str, key: str, deprecated_section: str, deprecated_name: str, extra_stacklevel: int
+ ):
if section == deprecated_section:
warnings.warn(
- f'The {deprecated_name} option in [{section}] has been renamed to {key} - '
- f'the old setting has been used, but please update your config.',
+ f"The {deprecated_name} option in [{section}] has been renamed to {key} - "
+ f"the old setting has been used, but please update your config.",
DeprecationWarning,
- stacklevel=3,
+ stacklevel=4 + extra_stacklevel,
)
else:
warnings.warn(
- f'The {deprecated_name} option in [{deprecated_section}] has been moved to the {key} option '
- f'in [{section}] - the old setting has been used, but please update your config.',
+ f"The {deprecated_name} option in [{deprecated_section}] has been moved to the {key} option "
+ f"in [{section}] - the old setting has been used, but please update your config.",
DeprecationWarning,
- stacklevel=3,
+ stacklevel=4 + extra_stacklevel,
)
def __getstate__(self):
return {
name: getattr(self, name)
for name in [
- '_sections',
- 'is_validated',
- 'airflow_defaults',
+ "_sections",
+ "is_validated",
+ "airflow_defaults",
]
}
def __setstate__(self, state):
self.__init__()
- config = state.pop('_sections')
+ config = state.pop("_sections")
self.read_dict(config)
self.__dict__.update(state)
def get_airflow_home() -> str:
- """Get path to Airflow Home"""
- return expand_env_var(os.environ.get('AIRFLOW_HOME', '~/airflow'))
+ """Get path to Airflow Home."""
+ return expand_env_var(os.environ.get("AIRFLOW_HOME", "~/airflow"))
def get_airflow_config(airflow_home) -> str:
- """Get Path to airflow.cfg path"""
- airflow_config_var = os.environ.get('AIRFLOW_CONFIG')
+ """Get Path to airflow.cfg path."""
+ airflow_config_var = os.environ.get("AIRFLOW_CONFIG")
if airflow_config_var is None:
- return os.path.join(airflow_home, 'airflow.cfg')
+ return os.path.join(airflow_home, "airflow.cfg")
return expand_env_var(airflow_config_var)
def _parameterized_config_from_template(filename) -> str:
- TEMPLATE_START = '# ----------------------- TEMPLATE BEGINS HERE -----------------------\n'
+ TEMPLATE_START = "# ----------------------- TEMPLATE BEGINS HERE -----------------------\n"
path = _default_config_file_path(filename)
with open(path) as fh:
@@ -1262,8 +1435,7 @@ def _parameterized_config_from_template(filename) -> str:
def parameterized_config(template) -> str:
"""
- Generates a configuration from the provided template + variables defined in
- current scope
+ Generates configuration from provided template & variables defined in current scope.
:param template: a config content templated with {{variables}}
"""
@@ -1272,11 +1444,11 @@ def parameterized_config(template) -> str:
def get_airflow_test_config(airflow_home) -> str:
- """Get path to unittests.cfg"""
- if 'AIRFLOW_TEST_CONFIG' not in os.environ:
- return os.path.join(airflow_home, 'unittests.cfg')
+ """Get path to unittests.cfg."""
+ if "AIRFLOW_TEST_CONFIG" not in os.environ:
+ return os.path.join(airflow_home, "unittests.cfg")
# It will never return None
- return expand_env_var(os.environ['AIRFLOW_TEST_CONFIG']) # type: ignore[return-value]
+ return expand_env_var(os.environ["AIRFLOW_TEST_CONFIG"]) # type: ignore[return-value]
def _generate_fernet_key() -> str:
@@ -1293,22 +1465,22 @@ def initialize_config() -> AirflowConfigParser:
"""
global FERNET_KEY, AIRFLOW_HOME
- default_config = _parameterized_config_from_template('default_airflow.cfg')
+ default_config = _parameterized_config_from_template("default_airflow.cfg")
local_conf = AirflowConfigParser(default_config=default_config)
- if local_conf.getboolean('core', 'unit_test_mode'):
+ if local_conf.getboolean("core", "unit_test_mode"):
# Load test config only
if not os.path.isfile(TEST_CONFIG_FILE):
from cryptography.fernet import Fernet
- log.info('Creating new Airflow config file for unit tests in: %s', TEST_CONFIG_FILE)
+ log.info("Creating new Airflow config file for unit tests in: %s", TEST_CONFIG_FILE)
pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
FERNET_KEY = Fernet.generate_key().decode()
- with open(TEST_CONFIG_FILE, 'w') as file:
- cfg = _parameterized_config_from_template('default_test.cfg')
+ with open(TEST_CONFIG_FILE, "w") as file:
+ cfg = _parameterized_config_from_template("default_test.cfg")
file.write(cfg)
local_conf.load_test_config()
@@ -1317,58 +1489,58 @@ def initialize_config() -> AirflowConfigParser:
if not os.path.isfile(AIRFLOW_CONFIG):
from cryptography.fernet import Fernet
- log.info('Creating new Airflow config file in: %s', AIRFLOW_CONFIG)
+ log.info("Creating new Airflow config file in: %s", AIRFLOW_CONFIG)
pathlib.Path(AIRFLOW_HOME).mkdir(parents=True, exist_ok=True)
FERNET_KEY = Fernet.generate_key().decode()
- with open(AIRFLOW_CONFIG, 'w') as file:
+ with open(AIRFLOW_CONFIG, "w") as file:
file.write(default_config)
log.info("Reading the config from %s", AIRFLOW_CONFIG)
local_conf.read(AIRFLOW_CONFIG)
- if local_conf.has_option('core', 'AIRFLOW_HOME'):
+ if local_conf.has_option("core", "AIRFLOW_HOME"):
msg = (
- 'Specifying both AIRFLOW_HOME environment variable and airflow_home '
- 'in the config file is deprecated. Please use only the AIRFLOW_HOME '
- 'environment variable and remove the config file entry.'
+ "Specifying both AIRFLOW_HOME environment variable and airflow_home "
+ "in the config file is deprecated. Please use only the AIRFLOW_HOME "
+ "environment variable and remove the config file entry."
)
- if 'AIRFLOW_HOME' in os.environ:
+ if "AIRFLOW_HOME" in os.environ:
warnings.warn(msg, category=DeprecationWarning)
- elif local_conf.get('core', 'airflow_home') == AIRFLOW_HOME:
+ elif local_conf.get("core", "airflow_home") == AIRFLOW_HOME:
warnings.warn(
- 'Specifying airflow_home in the config file is deprecated. As you '
- 'have left it at the default value you should remove the setting '
- 'from your airflow.cfg and suffer no change in behaviour.',
+ "Specifying airflow_home in the config file is deprecated. As you "
+ "have left it at the default value you should remove the setting "
+ "from your airflow.cfg and suffer no change in behaviour.",
category=DeprecationWarning,
)
else:
# there
- AIRFLOW_HOME = local_conf.get('core', 'airflow_home') # type: ignore[assignment]
+ AIRFLOW_HOME = local_conf.get("core", "airflow_home") # type: ignore[assignment]
warnings.warn(msg, category=DeprecationWarning)
# They _might_ have set unit_test_mode in the airflow.cfg, we still
# want to respect that and then load the unittests.cfg
- if local_conf.getboolean('core', 'unit_test_mode'):
+ if local_conf.getboolean("core", "unit_test_mode"):
local_conf.load_test_config()
# Make it no longer a proxy variable, just set it to an actual string
global WEBSERVER_CONFIG
- WEBSERVER_CONFIG = AIRFLOW_HOME + '/webserver_config.py'
+ WEBSERVER_CONFIG = AIRFLOW_HOME + "/webserver_config.py"
if not os.path.isfile(WEBSERVER_CONFIG):
import shutil
- log.info('Creating new FAB webserver config file in: %s', WEBSERVER_CONFIG)
- shutil.copy(_default_config_file_path('default_webserver_config.py'), WEBSERVER_CONFIG)
+ log.info("Creating new FAB webserver config file in: %s", WEBSERVER_CONFIG)
+ shutil.copy(_default_config_file_path("default_webserver_config.py"), WEBSERVER_CONFIG)
return local_conf
# Historical convenience functions to access config entries
def load_test_config():
- """Historical load_test_config"""
+ """Historical load_test_config."""
warnings.warn(
"Accessing configuration method 'load_test_config' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1379,8 +1551,8 @@ def load_test_config():
conf.load_test_config()
-def get(*args, **kwargs) -> Optional[ConfigType]:
- """Historical get"""
+def get(*args, **kwargs) -> ConfigType | None:
+ """Historical get."""
warnings.warn(
"Accessing configuration method 'get' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1392,7 +1564,7 @@ def get(*args, **kwargs) -> Optional[ConfigType]:
def getboolean(*args, **kwargs) -> bool:
- """Historical getboolean"""
+ """Historical getboolean."""
warnings.warn(
"Accessing configuration method 'getboolean' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1404,7 +1576,7 @@ def getboolean(*args, **kwargs) -> bool:
def getfloat(*args, **kwargs) -> float:
- """Historical getfloat"""
+ """Historical getfloat."""
warnings.warn(
"Accessing configuration method 'getfloat' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1416,7 +1588,7 @@ def getfloat(*args, **kwargs) -> float:
def getint(*args, **kwargs) -> int:
- """Historical getint"""
+ """Historical getint."""
warnings.warn(
"Accessing configuration method 'getint' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1427,8 +1599,8 @@ def getint(*args, **kwargs) -> int:
return conf.getint(*args, **kwargs)
-def getsection(*args, **kwargs) -> Optional[ConfigOptionsDictType]:
- """Historical getsection"""
+def getsection(*args, **kwargs) -> ConfigOptionsDictType | None:
+ """Historical getsection."""
warnings.warn(
"Accessing configuration method 'getsection' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1440,7 +1612,7 @@ def getsection(*args, **kwargs) -> Optional[ConfigOptionsDictType]:
def has_option(*args, **kwargs) -> bool:
- """Historical has_option"""
+ """Historical has_option."""
warnings.warn(
"Accessing configuration method 'has_option' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1452,7 +1624,7 @@ def has_option(*args, **kwargs) -> bool:
def remove_option(*args, **kwargs) -> bool:
- """Historical remove_option"""
+ """Historical remove_option."""
warnings.warn(
"Accessing configuration method 'remove_option' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1464,7 +1636,7 @@ def remove_option(*args, **kwargs) -> bool:
def as_dict(*args, **kwargs) -> ConfigSourcesType:
- """Historical as_dict"""
+ """Historical as_dict."""
warnings.warn(
"Accessing configuration method 'as_dict' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1476,7 +1648,7 @@ def as_dict(*args, **kwargs) -> ConfigSourcesType:
def set(*args, **kwargs) -> None:
- """Historical set"""
+ """Historical set."""
warnings.warn(
"Accessing configuration method 'set' directly from the configuration module is "
"deprecated. Please access the configuration from the 'configuration.conf' object via "
@@ -1487,7 +1659,7 @@ def set(*args, **kwargs) -> None:
conf.set(*args, **kwargs)
-def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
+def ensure_secrets_loaded() -> list[BaseSecretsBackend]:
"""
Ensure that all secrets backends are loaded.
If the secrets_backend_list contains only 2 default backends, reload it.
@@ -1498,23 +1670,33 @@ def ensure_secrets_loaded() -> List[BaseSecretsBackend]:
return secrets_backend_list
-def get_custom_secret_backend() -> Optional[BaseSecretsBackend]:
- """Get Secret Backend if defined in airflow.cfg"""
- secrets_backend_cls = conf.getimport(section='secrets', key='backend')
-
- if secrets_backend_cls:
- try:
- backends: Any = conf.get(section='secrets', key='backend_kwargs', fallback='{}')
- alternative_secrets_config_dict = json.loads(backends)
- except JSONDecodeError:
- alternative_secrets_config_dict = {}
-
- return secrets_backend_cls(**alternative_secrets_config_dict)
- return None
+def get_custom_secret_backend() -> BaseSecretsBackend | None:
+ """Get Secret Backend if defined in airflow.cfg."""
+ secrets_backend_cls = conf.getimport(section="secrets", key="backend")
+ if not secrets_backend_cls:
+ return None
-def initialize_secrets_backends() -> List[BaseSecretsBackend]:
+ try:
+ backend_kwargs = conf.getjson(section="secrets", key="backend_kwargs")
+ if not backend_kwargs:
+ backend_kwargs = {}
+ elif not isinstance(backend_kwargs, dict):
+ raise ValueError("not a dict")
+ except AirflowConfigException:
+ log.warning("Failed to parse [secrets] backend_kwargs as JSON, defaulting to no kwargs.")
+ backend_kwargs = {}
+ except ValueError:
+ log.warning("Failed to parse [secrets] backend_kwargs into a dict, defaulting to no kwargs.")
+ backend_kwargs = {}
+
+ return secrets_backend_cls(**backend_kwargs)
+
+
+def initialize_secrets_backends() -> list[BaseSecretsBackend]:
"""
+ Initialize secrets backend.
+
* import secrets backend classes
* instantiate them and return them in a list
"""
@@ -1534,23 +1716,23 @@ def initialize_secrets_backends() -> List[BaseSecretsBackend]:
@functools.lru_cache(maxsize=None)
def _DEFAULT_CONFIG() -> str:
- path = _default_config_file_path('default_airflow.cfg')
+ path = _default_config_file_path("default_airflow.cfg")
with open(path) as fh:
return fh.read()
@functools.lru_cache(maxsize=None)
def _TEST_CONFIG() -> str:
- path = _default_config_file_path('default_test.cfg')
+ path = _default_config_file_path("default_test.cfg")
with open(path) as fh:
return fh.read()
_deprecated = {
- 'DEFAULT_CONFIG': _DEFAULT_CONFIG,
- 'TEST_CONFIG': _TEST_CONFIG,
- 'TEST_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, 'default_test.cfg'),
- 'DEFAULT_CONFIG_FILE_PATH': functools.partial(_default_config_file_path, 'default_airflow.cfg'),
+ "DEFAULT_CONFIG": _DEFAULT_CONFIG,
+ "TEST_CONFIG": _TEST_CONFIG,
+ "TEST_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_test.cfg"),
+ "DEFAULT_CONFIG_FILE_PATH": functools.partial(_default_config_file_path, "default_airflow.cfg"),
}
@@ -1574,28 +1756,28 @@ def __getattr__(name):
# Set up dags folder for unit tests
# this directory won't exist if users install via pip
_TEST_DAGS_FOLDER = os.path.join(
- os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'dags'
+ os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "tests", "dags"
)
if os.path.exists(_TEST_DAGS_FOLDER):
TEST_DAGS_FOLDER = _TEST_DAGS_FOLDER
else:
- TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, 'dags')
+ TEST_DAGS_FOLDER = os.path.join(AIRFLOW_HOME, "dags")
# Set up plugins folder for unit tests
_TEST_PLUGINS_FOLDER = os.path.join(
- os.path.dirname(os.path.dirname(os.path.realpath(__file__))), 'tests', 'plugins'
+ os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "tests", "plugins"
)
if os.path.exists(_TEST_PLUGINS_FOLDER):
TEST_PLUGINS_FOLDER = _TEST_PLUGINS_FOLDER
else:
- TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, 'plugins')
+ TEST_PLUGINS_FOLDER = os.path.join(AIRFLOW_HOME, "plugins")
TEST_CONFIG_FILE = get_airflow_test_config(AIRFLOW_HOME)
-SECRET_KEY = b64encode(os.urandom(16)).decode('utf-8')
-FERNET_KEY = '' # Set only if needed when generating a new file
-WEBSERVER_CONFIG = '' # Set by initialize_config
+SECRET_KEY = b64encode(os.urandom(16)).decode("utf-8")
+FERNET_KEY = "" # Set only if needed when generating a new file
+WEBSERVER_CONFIG = "" # Set by initialize_config
conf = initialize_config()
secrets_backend_list = initialize_secrets_backends()
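
Note on the ``get_custom_secret_backend`` change above: the new code falls back to empty kwargs in two distinct cases, when ``backend_kwargs`` is not valid JSON and when it is valid JSON but not a dict. The following is a minimal standalone sketch of that fallback logic, assuming plain ``json.loads`` in place of ``conf.getjson``. The helper name ``parse_backend_kwargs`` is hypothetical and does not exist in Airflow.

```python
import json
import logging

log = logging.getLogger(__name__)


def parse_backend_kwargs(raw):
    """Hypothetical helper: turn a raw JSON string into constructor kwargs.

    Mirrors the fallback behaviour in the diff: any unusable value
    results in an empty dict rather than an exception.
    """
    if not raw:
        # Unset or empty option: no kwargs.
        return {}
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        log.warning("Failed to parse backend_kwargs as JSON, defaulting to no kwargs.")
        return {}
    if not parsed:
        # Valid JSON but empty/falsy (e.g. "{}", "[]", "null"): no kwargs.
        return {}
    if not isinstance(parsed, dict):
        # Valid JSON of the wrong shape, e.g. a list.
        log.warning("Failed to parse backend_kwargs into a dict, defaulting to no kwargs.")
        return {}
    return parsed
```

The old code silently swallowed a ``JSONDecodeError``; the new version logs a warning in both failure cases, so a misconfigured ``[secrets] backend_kwargs`` is visible in the logs instead of being ignored quietly.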
diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index c861d649e4a49..cffc937097306 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -15,13 +15,323 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""This package is deprecated. Please use `airflow.hooks` or `airflow.providers.*.hooks`."""
+from __future__ import annotations
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
warnings.warn(
"This package is deprecated. Please use `airflow.hooks` or `airflow.providers.*.hooks`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
+
+__deprecated_classes = {
+ 'aws_athena_hook': {
+ 'AWSAthenaHook': 'airflow.providers.amazon.aws.hooks.athena.AthenaHook',
+ },
+ 'aws_datasync_hook': {
+ 'AWSDataSyncHook': 'airflow.providers.amazon.aws.hooks.datasync.DataSyncHook',
+ },
+ 'aws_dynamodb_hook': {
+ 'AwsDynamoDBHook': 'airflow.providers.amazon.aws.hooks.dynamodb.DynamoDBHook',
+ },
+ 'aws_firehose_hook': {
+ 'FirehoseHook': 'airflow.providers.amazon.aws.hooks.kinesis.FirehoseHook',
+ },
+ 'aws_glue_catalog_hook': {
+ 'AwsGlueCatalogHook': 'airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook',
+ },
+ 'aws_hook': {
+ 'AwsBaseHook': 'airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook',
+ '_parse_s3_config': 'airflow.providers.amazon.aws.hooks.base_aws._parse_s3_config',
+ 'boto3': 'airflow.providers.amazon.aws.hooks.base_aws.boto3',
+ 'AwsHook': 'airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook',
+ },
+ 'aws_lambda_hook': {
+ 'AwsLambdaHook': 'airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook',
+ },
+ 'aws_logs_hook': {
+ 'AwsLogsHook': 'airflow.providers.amazon.aws.hooks.logs.AwsLogsHook',
+ },
+ 'aws_sns_hook': {
+ 'AwsSnsHook': 'airflow.providers.amazon.aws.hooks.sns.SnsHook',
+ },
+ 'aws_sqs_hook': {
+ 'SqsHook': 'airflow.providers.amazon.aws.hooks.sqs.SqsHook',
+ 'SQSHook': 'airflow.providers.amazon.aws.hooks.sqs.SqsHook',
+ },
+ 'azure_container_instance_hook': {
+ 'AzureContainerInstanceHook': (
+ 'airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook'
+ ),
+ },
+ 'azure_container_registry_hook': {
+ 'AzureContainerRegistryHook': (
+ 'airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook'
+ ),
+ },
+ 'azure_container_volume_hook': {
+ 'AzureContainerVolumeHook': (
+ 'airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook'
+ ),
+ },
+ 'azure_cosmos_hook': {
+ 'AzureCosmosDBHook': 'airflow.providers.microsoft.azure.hooks.cosmos.AzureCosmosDBHook',
+ },
+ 'azure_data_lake_hook': {
+ 'AzureDataLakeHook': 'airflow.providers.microsoft.azure.hooks.data_lake.AzureDataLakeHook',
+ },
+ 'azure_fileshare_hook': {
+ 'AzureFileShareHook': 'airflow.providers.microsoft.azure.hooks.fileshare.AzureFileShareHook',
+ },
+ 'bigquery_hook': {
+ 'BigQueryBaseCursor': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryBaseCursor',
+ 'BigQueryConnection': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryConnection',
+ 'BigQueryCursor': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor',
+ 'BigQueryHook': 'airflow.providers.google.cloud.hooks.bigquery.BigQueryHook',
+ 'GbqConnector': 'airflow.providers.google.cloud.hooks.bigquery.GbqConnector',
+ },
+ 'cassandra_hook': {
+ 'CassandraHook': 'airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook',
+ },
+ 'cloudant_hook': {
+ 'CloudantHook': 'airflow.providers.cloudant.hooks.cloudant.CloudantHook',
+ },
+ 'databricks_hook': {
+ 'CANCEL_RUN_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.CANCEL_RUN_ENDPOINT',
+ 'GET_RUN_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.GET_RUN_ENDPOINT',
+ 'RESTART_CLUSTER_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.RESTART_CLUSTER_ENDPOINT',
+ 'RUN_LIFE_CYCLE_STATES': 'airflow.providers.databricks.hooks.databricks.RUN_LIFE_CYCLE_STATES',
+ 'RUN_NOW_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.RUN_NOW_ENDPOINT',
+ 'START_CLUSTER_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.START_CLUSTER_ENDPOINT',
+ 'SUBMIT_RUN_ENDPOINT': 'airflow.providers.databricks.hooks.databricks.SUBMIT_RUN_ENDPOINT',
+ 'TERMINATE_CLUSTER_ENDPOINT': (
+ 'airflow.providers.databricks.hooks.databricks.TERMINATE_CLUSTER_ENDPOINT'
+ ),
+ 'DatabricksHook': 'airflow.providers.databricks.hooks.databricks.DatabricksHook',
+ 'RunState': 'airflow.providers.databricks.hooks.databricks.RunState',
+ },
+ 'datadog_hook': {
+ 'DatadogHook': 'airflow.providers.datadog.hooks.datadog.DatadogHook',
+ },
+ 'datastore_hook': {
+ 'DatastoreHook': 'airflow.providers.google.cloud.hooks.datastore.DatastoreHook',
+ },
+ 'dingding_hook': {
+ 'DingdingHook': 'airflow.providers.dingding.hooks.dingding.DingdingHook',
+ 'requests': 'airflow.providers.dingding.hooks.dingding.requests',
+ },
+ 'discord_webhook_hook': {
+ 'DiscordWebhookHook': 'airflow.providers.discord.hooks.discord_webhook.DiscordWebhookHook',
+ },
+ 'emr_hook': {
+ 'EmrHook': 'airflow.providers.amazon.aws.hooks.emr.EmrHook',
+ },
+ 'fs_hook': {
+ 'FSHook': 'airflow.hooks.filesystem.FSHook',
+ },
+ 'ftp_hook': {
+ 'FTPHook': 'airflow.providers.ftp.hooks.ftp.FTPHook',
+ 'FTPSHook': 'airflow.providers.ftp.hooks.ftp.FTPSHook',
+ },
+ 'gcp_api_base_hook': {
+ 'GoogleBaseHook': 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook',
+ 'GoogleCloudBaseHook': 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook',
+ },
+ 'gcp_bigtable_hook': {
+ 'BigtableHook': 'airflow.providers.google.cloud.hooks.bigtable.BigtableHook',
+ },
+ 'gcp_cloud_build_hook': {
+ 'CloudBuildHook': 'airflow.providers.google.cloud.hooks.cloud_build.CloudBuildHook',
+ },
+ 'gcp_compute_hook': {
+ 'ComputeEngineHook': 'airflow.providers.google.cloud.hooks.compute.ComputeEngineHook',
+ 'GceHook': 'airflow.providers.google.cloud.hooks.compute.ComputeEngineHook',
+ },
+ 'gcp_container_hook': {
+ 'GKEHook': 'airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook',
+ 'GKEClusterHook': 'airflow.providers.google.cloud.hooks.kubernetes_engine.GKEHook',
+ },
+ 'gcp_dataflow_hook': {
+ 'DataflowHook': 'airflow.providers.google.cloud.hooks.dataflow.DataflowHook',
+ 'DataFlowHook': 'airflow.providers.google.cloud.hooks.dataflow.DataflowHook',
+ },
+ 'gcp_dataproc_hook': {
+ 'DataprocHook': 'airflow.providers.google.cloud.hooks.dataproc.DataprocHook',
+ 'DataProcHook': 'airflow.providers.google.cloud.hooks.dataproc.DataprocHook',
+ },
+ 'gcp_dlp_hook': {
+ 'CloudDLPHook': 'airflow.providers.google.cloud.hooks.dlp.CloudDLPHook',
+ 'DlpJob': 'airflow.providers.google.cloud.hooks.dlp.DlpJob',
+ },
+ 'gcp_function_hook': {
+ 'CloudFunctionsHook': 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook',
+ 'GcfHook': 'airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook',
+ },
+ 'gcp_kms_hook': {
+ 'CloudKMSHook': 'airflow.providers.google.cloud.hooks.kms.CloudKMSHook',
+ 'GoogleCloudKMSHook': 'airflow.providers.google.cloud.hooks.kms.CloudKMSHook',
+ },
+ 'gcp_mlengine_hook': {
+ 'MLEngineHook': 'airflow.providers.google.cloud.hooks.mlengine.MLEngineHook',
+ },
+ 'gcp_natural_language_hook': {
+ 'CloudNaturalLanguageHook': (
+ 'airflow.providers.google.cloud.hooks.natural_language.CloudNaturalLanguageHook'
+ ),
+ },
+ 'gcp_pubsub_hook': {
+ 'PubSubException': 'airflow.providers.google.cloud.hooks.pubsub.PubSubException',
+ 'PubSubHook': 'airflow.providers.google.cloud.hooks.pubsub.PubSubHook',
+ },
+ 'gcp_spanner_hook': {
+ 'SpannerHook': 'airflow.providers.google.cloud.hooks.spanner.SpannerHook',
+ 'CloudSpannerHook': 'airflow.providers.google.cloud.hooks.spanner.SpannerHook',
+ },
+ 'gcp_speech_to_text_hook': {
+ 'CloudSpeechToTextHook': 'airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook',
+ 'GCPSpeechToTextHook': 'airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook',
+ },
+ 'gcp_sql_hook': {
+ 'CloudSQLDatabaseHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook',
+ 'CloudSQLHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook',
+ 'CloudSqlDatabaseHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook',
+ 'CloudSqlHook': 'airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLHook',
+ },
+ 'gcp_tasks_hook': {
+ 'CloudTasksHook': 'airflow.providers.google.cloud.hooks.tasks.CloudTasksHook',
+ },
+ 'gcp_text_to_speech_hook': {
+ 'CloudTextToSpeechHook': 'airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook',
+ 'GCPTextToSpeechHook': 'airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook',
+ },
+ 'gcp_transfer_hook': {
+ 'CloudDataTransferServiceHook': (
+ 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceHook'
+ ),
+ 'GCPTransferServiceHook': (
+ 'airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceHook'
+ ),
+ },
+ 'gcp_translate_hook': {
+ 'CloudTranslateHook': 'airflow.providers.google.cloud.hooks.translate.CloudTranslateHook',
+ },
+ 'gcp_video_intelligence_hook': {
+ 'CloudVideoIntelligenceHook': (
+ 'airflow.providers.google.cloud.hooks.video_intelligence.CloudVideoIntelligenceHook'
+ ),
+ },
+ 'gcp_vision_hook': {
+ 'CloudVisionHook': 'airflow.providers.google.cloud.hooks.vision.CloudVisionHook',
+ },
+ 'gcs_hook': {
+ 'GCSHook': 'airflow.providers.google.cloud.hooks.gcs.GCSHook',
+ 'GoogleCloudStorageHook': 'airflow.providers.google.cloud.hooks.gcs.GCSHook',
+ },
+ 'gdrive_hook': {
+ 'GoogleDriveHook': 'airflow.providers.google.suite.hooks.drive.GoogleDriveHook',
+ },
+ 'grpc_hook': {
+ 'GrpcHook': 'airflow.providers.grpc.hooks.grpc.GrpcHook',
+ },
+ 'imap_hook': {
+ 'ImapHook': 'airflow.providers.imap.hooks.imap.ImapHook',
+ 'Mail': 'airflow.providers.imap.hooks.imap.Mail',
+ 'MailPart': 'airflow.providers.imap.hooks.imap.MailPart',
+ },
+ 'jenkins_hook': {
+ 'JenkinsHook': 'airflow.providers.jenkins.hooks.jenkins.JenkinsHook',
+ },
+ 'jira_hook': {
+ 'JiraHook': 'airflow.providers.atlassian.jira.hooks.jira.JiraHook',
+ },
+ 'mongo_hook': {
+ 'MongoHook': 'airflow.providers.mongo.hooks.mongo.MongoHook',
+ },
+ 'openfaas_hook': {
+ 'OK_STATUS_CODE': 'airflow.providers.openfaas.hooks.openfaas.OK_STATUS_CODE',
+ 'OpenFaasHook': 'airflow.providers.openfaas.hooks.openfaas.OpenFaasHook',
+ 'requests': 'airflow.providers.openfaas.hooks.openfaas.requests',
+ },
+ 'opsgenie_alert_hook': {
+ 'OpsgenieAlertHook': 'airflow.providers.opsgenie.hooks.opsgenie.OpsgenieAlertHook',
+ },
+ 'pagerduty_hook': {
+ 'PagerdutyHook': 'airflow.providers.pagerduty.hooks.pagerduty.PagerdutyHook',
+ },
+ 'pinot_hook': {
+ 'PinotAdminHook': 'airflow.providers.apache.pinot.hooks.pinot.PinotAdminHook',
+ 'PinotDbApiHook': 'airflow.providers.apache.pinot.hooks.pinot.PinotDbApiHook',
+ },
+ 'qubole_check_hook': {
+ 'QuboleCheckHook': 'airflow.providers.qubole.hooks.qubole_check.QuboleCheckHook',
+ },
+ 'qubole_hook': {
+ 'QuboleHook': 'airflow.providers.qubole.hooks.qubole.QuboleHook',
+ },
+ 'redis_hook': {
+ 'RedisHook': 'airflow.providers.redis.hooks.redis.RedisHook',
+ },
+ 'redshift_hook': {
+ 'RedshiftHook': 'airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook',
+ },
+ 'sagemaker_hook': {
+ 'LogState': 'airflow.providers.amazon.aws.hooks.sagemaker.LogState',
+ 'Position': 'airflow.providers.amazon.aws.hooks.sagemaker.Position',
+ 'SageMakerHook': 'airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook',
+ 'argmin': 'airflow.providers.amazon.aws.hooks.sagemaker.argmin',
+ 'secondary_training_status_changed': (
+ 'airflow.providers.amazon.aws.hooks.sagemaker.secondary_training_status_changed'
+ ),
+ 'secondary_training_status_message': (
+ 'airflow.providers.amazon.aws.hooks.sagemaker.secondary_training_status_message'
+ ),
+ },
+ 'salesforce_hook': {
+ 'SalesforceHook': 'airflow.providers.salesforce.hooks.salesforce.SalesforceHook',
+ 'pd': 'airflow.providers.salesforce.hooks.salesforce.pd',
+ },
+ 'segment_hook': {
+ 'SegmentHook': 'airflow.providers.segment.hooks.segment.SegmentHook',
+ 'analytics': 'airflow.providers.segment.hooks.segment.analytics',
+ },
+ 'sftp_hook': {
+ 'SFTPHook': 'airflow.providers.sftp.hooks.sftp.SFTPHook',
+ },
+ 'slack_webhook_hook': {
+ 'SlackWebhookHook': 'airflow.providers.slack.hooks.slack_webhook.SlackWebhookHook',
+ },
+ 'snowflake_hook': {
+ 'SnowflakeHook': 'airflow.providers.snowflake.hooks.snowflake.SnowflakeHook',
+ },
+ 'spark_jdbc_hook': {
+ 'SparkJDBCHook': 'airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook',
+ },
+ 'spark_sql_hook': {
+ 'SparkSqlHook': 'airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook',
+ },
+ 'spark_submit_hook': {
+ 'SparkSubmitHook': 'airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook',
+ },
+ 'sqoop_hook': {
+ 'SqoopHook': 'airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook',
+ },
+ 'ssh_hook': {
+ 'SSHHook': 'airflow.providers.ssh.hooks.ssh.SSHHook',
+ },
+ 'vertica_hook': {
+ 'VerticaHook': 'airflow.providers.vertica.hooks.vertica.VerticaHook',
+ },
+ 'wasb_hook': {
+ 'WasbHook': 'airflow.providers.microsoft.azure.hooks.wasb.WasbHook',
+ },
+ 'winrm_hook': {
+ 'WinRMHook': 'airflow.providers.microsoft.winrm.hooks.winrm.WinRMHook',
+ },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
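
Note on the ``__deprecated_classes`` mapping above: each inner key is an old attribute name and each value is the dotted path of its replacement. The implementation of ``add_deprecated_classes`` is not part of this diff, so the following is only a rough sketch of how such a mapping could be resolved at lookup time. The function ``resolve_deprecated`` and the ``EXAMPLE_MAPPING`` entries are hypothetical stand-ins, and the standard library is used as the redirect target so the sketch runs without Airflow installed.

```python
import importlib
import warnings


def resolve_deprecated(mapping, module_name, attr_name):
    """Hypothetical resolver: warn, then import the replacement object.

    ``mapping`` has the same shape as ``__deprecated_classes``:
    {old_module_name: {old_attr_name: "new.dotted.path.Attr"}}.
    """
    new_path = mapping[module_name][attr_name]
    # Split "pkg.module.Attr" into "pkg.module" and "Attr".
    new_module, _, new_attr = new_path.rpartition(".")
    warnings.warn(
        f"{module_name}.{attr_name} is deprecated. Please use {new_path}.",
        DeprecationWarning,
        stacklevel=2,
    )
    return getattr(importlib.import_module(new_module), new_attr)


# Stand-in mapping (not from Airflow) pointing at a stdlib class.
EXAMPLE_MAPPING = {
    "old_json_tools": {"Decoder": "json.JSONDecoder"},
}
```

Keeping the redirects in one table is what allows the per-module shim files (``aws_athena_hook.py``, ``aws_datasync_hook.py``, and so on) to be deleted in the rest of this diff.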
diff --git a/airflow/contrib/hooks/aws_athena_hook.py b/airflow/contrib/hooks/aws_athena_hook.py
deleted file mode 100644
index db1ecdfdbf3c5..0000000000000
--- a/airflow/contrib/hooks/aws_athena_hook.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.athena`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.athena import AWSAthenaHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.athena`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_datasync_hook.py b/airflow/contrib/hooks/aws_datasync_hook.py
deleted file mode 100644
index 0d485475b0310..0000000000000
--- a/airflow/contrib/hooks/aws_datasync_hook.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.datasync`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.datasync import AWSDataSyncHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.datasync`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_dynamodb_hook.py b/airflow/contrib/hooks/aws_dynamodb_hook.py
deleted file mode 100644
index dedb80073e3e5..0000000000000
--- a/airflow/contrib/hooks/aws_dynamodb_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.dynamodb`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_firehose_hook.py b/airflow/contrib/hooks/aws_firehose_hook.py
deleted file mode 100644
index c6d39cd795b79..0000000000000
--- a/airflow/contrib/hooks/aws_firehose_hook.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.kinesis`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.kinesis import AwsFirehoseHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.kinesis`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_glue_catalog_hook.py b/airflow/contrib/hooks/aws_glue_catalog_hook.py
deleted file mode 100644
index 703ba47b81bf3..0000000000000
--- a/airflow/contrib/hooks/aws_glue_catalog_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.glue_catalog`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.glue_catalog import AwsGlueCatalogHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.glue_catalog`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_hook.py b/airflow/contrib/hooks/aws_hook.py
deleted file mode 100644
index c40e32c0305f1..0000000000000
--- a/airflow/contrib/hooks/aws_hook.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.base_aws`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook, _parse_s3_config, boto3 # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class AwsHook(AwsBaseHook):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/aws_lambda_hook.py b/airflow/contrib/hooks/aws_lambda_hook.py
deleted file mode 100644
index 379aaf5486655..0000000000000
--- a/airflow/contrib/hooks/aws_lambda_hook.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.hooks.lambda_function`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.lambda_function import AwsLambdaHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.lambda_function`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_logs_hook.py b/airflow/contrib/hooks/aws_logs_hook.py
deleted file mode 100644
index 9b9c449f8415e..0000000000000
--- a/airflow/contrib/hooks/aws_logs_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.logs`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.logs`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_sns_hook.py b/airflow/contrib/hooks/aws_sns_hook.py
deleted file mode 100644
index b1318f52add2d..0000000000000
--- a/airflow/contrib/hooks/aws_sns_hook.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.sns`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.sns import AwsSnsHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sns`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/aws_sqs_hook.py b/airflow/contrib/hooks/aws_sqs_hook.py
deleted file mode 100644
index aafd8379ab5dd..0000000000000
--- a/airflow/contrib/hooks/aws_sqs_hook.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.sqs`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sqs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class SQSHook(SqsHook):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.sqs.SqsHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.sqs.SqsHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/azure_container_instance_hook.py b/airflow/contrib/hooks/azure_container_instance_hook.py
deleted file mode 100644
index 9fefa5c679d38..0000000000000
--- a/airflow/contrib/hooks/azure_container_instance_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.container_instance`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.container_instance import AzureContainerInstanceHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_instance`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/azure_container_registry_hook.py b/airflow/contrib/hooks/azure_container_registry_hook.py
deleted file mode 100644
index 14e55ef820737..0000000000000
--- a/airflow/contrib/hooks/azure_container_registry_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_registry`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/azure_container_volume_hook.py b/airflow/contrib/hooks/azure_container_volume_hook.py
deleted file mode 100644
index facfdaca4cc8a..0000000000000
--- a/airflow/contrib/hooks/azure_container_volume_hook.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.microsoft.azure.hooks.container_volume`.
-"""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.container_volume`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/azure_cosmos_hook.py b/airflow/contrib/hooks/azure_cosmos_hook.py
deleted file mode 100644
index 4152f15e9d7b1..0000000000000
--- a/airflow/contrib/hooks/azure_cosmos_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.cosmos`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.cosmos`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/azure_data_lake_hook.py b/airflow/contrib/hooks/azure_data_lake_hook.py
deleted file mode 100644
index 3442d1345078c..0000000000000
--- a/airflow/contrib/hooks/azure_data_lake_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.data_lake`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.data_lake`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/azure_fileshare_hook.py b/airflow/contrib/hooks/azure_fileshare_hook.py
deleted file mode 100644
index f0a5b2ec4c9bb..0000000000000
--- a/airflow/contrib/hooks/azure_fileshare_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.fileshare`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.fileshare import AzureFileShareHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.fileshare`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/bigquery_hook.py b/airflow/contrib/hooks/bigquery_hook.py
deleted file mode 100644
index e9e10402943a5..0000000000000
--- a/airflow/contrib/hooks/bigquery_hook.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.bigquery import ( # noqa
- BigQueryBaseCursor,
- BigQueryConnection,
- BigQueryCursor,
- BigQueryHook,
- GbqConnector,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/cassandra_hook.py b/airflow/contrib/hooks/cassandra_hook.py
deleted file mode 100644
index ea4c748c2bfc5..0000000000000
--- a/airflow/contrib/hooks/cassandra_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.cassandra.hooks.cassandra`."""
-
-import warnings
-
-from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.cassandra.hooks.cassandra`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/cloudant_hook.py b/airflow/contrib/hooks/cloudant_hook.py
deleted file mode 100644
index ab7a1fa398fa7..0000000000000
--- a/airflow/contrib/hooks/cloudant_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.cloudant.hooks.cloudant`."""
-
-import warnings
-
-from airflow.providers.cloudant.hooks.cloudant import CloudantHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.cloudant.hooks.cloudant`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/databricks_hook.py b/airflow/contrib/hooks/databricks_hook.py
deleted file mode 100644
index 84746d39da8ae..0000000000000
--- a/airflow/contrib/hooks/databricks_hook.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.databricks.hooks.databricks`."""
-
-import warnings
-
-from airflow.providers.databricks.hooks.databricks import ( # noqa
- CANCEL_RUN_ENDPOINT,
- GET_RUN_ENDPOINT,
- RESTART_CLUSTER_ENDPOINT,
- RUN_LIFE_CYCLE_STATES,
- RUN_NOW_ENDPOINT,
- START_CLUSTER_ENDPOINT,
- SUBMIT_RUN_ENDPOINT,
- TERMINATE_CLUSTER_ENDPOINT,
- DatabricksHook,
- RunState,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.databricks.hooks.databricks`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/datadog_hook.py b/airflow/contrib/hooks/datadog_hook.py
deleted file mode 100644
index be275e9adf81d..0000000000000
--- a/airflow/contrib/hooks/datadog_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.datadog.hooks.datadog`."""
-
-import warnings
-
-from airflow.providers.datadog.hooks.datadog import DatadogHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.datadog.hooks.datadog`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/datastore_hook.py b/airflow/contrib/hooks/datastore_hook.py
deleted file mode 100644
index 9898e2fb16814..0000000000000
--- a/airflow/contrib/hooks/datastore_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.datastore`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.datastore import DatastoreHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.datastore`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/dingding_hook.py b/airflow/contrib/hooks/dingding_hook.py
deleted file mode 100644
index deff0414baf94..0000000000000
--- a/airflow/contrib/hooks/dingding_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.dingding.hooks.dingding`."""
-
-import warnings
-
-from airflow.providers.dingding.hooks.dingding import DingdingHook, requests # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.dingding.hooks.dingding`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/discord_webhook_hook.py b/airflow/contrib/hooks/discord_webhook_hook.py
deleted file mode 100644
index a907d2115a724..0000000000000
--- a/airflow/contrib/hooks/discord_webhook_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.discord.hooks.discord_webhook`."""
-
-import warnings
-
-from airflow.providers.discord.hooks.discord_webhook import DiscordWebhookHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.discord.hooks.discord_webhook`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/emr_hook.py b/airflow/contrib/hooks/emr_hook.py
deleted file mode 100644
index 1a15ee3edcd17..0000000000000
--- a/airflow/contrib/hooks/emr_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.emr import EmrHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/fs_hook.py b/airflow/contrib/hooks/fs_hook.py
deleted file mode 100644
index bc247c1948074..0000000000000
--- a/airflow/contrib/hooks/fs_hook.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.hooks.filesystem`."""
-
-import warnings
-
-from airflow.hooks.filesystem import FSHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.hooks.filesystem`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/contrib/hooks/ftp_hook.py b/airflow/contrib/hooks/ftp_hook.py
deleted file mode 100644
index 8d2e9cb06c843..0000000000000
--- a/airflow/contrib/hooks/ftp_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.ftp.hooks.ftp`."""
-
-import warnings
-
-from airflow.providers.ftp.hooks.ftp import FTPHook, FTPSHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.ftp.hooks.ftp`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_api_base_hook.py b/airflow/contrib/hooks/gcp_api_base_hook.py
deleted file mode 100644
index 3226a8683d7a7..0000000000000
--- a/airflow/contrib/hooks/gcp_api_base_hook.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.common.hooks.base_google`."""
-import warnings
-
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.common.hooks.base_google`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudBaseHook(GoogleBaseHook):
- """
- This class is deprecated. Please use
- `airflow.providers.google.common.hooks.base_google.GoogleBaseHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use "
- "`airflow.providers.google.common.hooks.base_google.GoogleBaseHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_bigtable_hook.py b/airflow/contrib/hooks/gcp_bigtable_hook.py
deleted file mode 100644
index 47ccd2414a839..0000000000000
--- a/airflow/contrib/hooks/gcp_bigtable_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.bigtable`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.bigtable import BigtableHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.bigtable`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_cloud_build_hook.py b/airflow/contrib/hooks/gcp_cloud_build_hook.py
deleted file mode 100644
index 691ae728a46d1..0000000000000
--- a/airflow/contrib/hooks/gcp_cloud_build_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.cloud_build`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.cloud_build import CloudBuildHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.cloud_build`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_compute_hook.py b/airflow/contrib/hooks/gcp_compute_hook.py
deleted file mode 100644
index 5ac9df1378561..0000000000000
--- a/airflow/contrib/hooks/gcp_compute_hook.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.compute`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.compute import ComputeEngineHook
-
-warnings.warn(
- "This module is deprecated. Please use airflow.providers.google.cloud.hooks.compute`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GceHook(ComputeEngineHook):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.google.cloud.hooks.compute.ComputeEngineHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.compute`.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_container_hook.py b/airflow/contrib/hooks/gcp_container_hook.py
deleted file mode 100644
index e825dbe5492c2..0000000000000
--- a/airflow/contrib/hooks/gcp_container_hook.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.hooks.kubernetes_engine`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kubernetes_engine`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GKEClusterHook(GKEHook):
- """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.container.GKEHook`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.container.GKEHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_dataflow_hook.py b/airflow/contrib/hooks/gcp_dataflow_hook.py
deleted file mode 100644
index f489ffad82006..0000000000000
--- a/airflow/contrib/hooks/gcp_dataflow_hook.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.dataflow`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.dataflow import DataflowHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataflow`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DataFlowHook(DataflowHook):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.google.cloud.hooks.dataflow.DataflowHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.google.cloud.hooks.dataflow.DataflowHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_dataproc_hook.py b/airflow/contrib/hooks/gcp_dataproc_hook.py
deleted file mode 100644
index 1f02c6de42f7e..0000000000000
--- a/airflow/contrib/hooks/gcp_dataproc_hook.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.dataproc`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dataproc`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DataProcHook(DataprocHook):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.hooks.dataproc.DataprocHook`.""",
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_dlp_hook.py b/airflow/contrib/hooks/gcp_dlp_hook.py
deleted file mode 100644
index 77a9da6b1a8a8..0000000000000
--- a/airflow/contrib/hooks/gcp_dlp_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.dlp`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.dlp import CloudDLPHook, DlpJob # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.dlp`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_function_hook.py b/airflow/contrib/hooks/gcp_function_hook.py
deleted file mode 100644
index dc6464fc07ce9..0000000000000
--- a/airflow/contrib/hooks/gcp_function_hook.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.functions`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.functions import CloudFunctionsHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.functions`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GcfHook(CloudFunctionsHook):
- """
- This class is deprecated. Please use
- `airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.google.cloud.hooks.functions.CloudFunctionsHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_kms_hook.py b/airflow/contrib/hooks/gcp_kms_hook.py
deleted file mode 100644
index 1b409be99f121..0000000000000
--- a/airflow/contrib/hooks/gcp_kms_hook.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.kms`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.kms import CloudKMSHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.kms`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudKMSHook(CloudKMSHook):
- """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.kms.CloudKMSHook`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.kms.CloudKMSHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_mlengine_hook.py b/airflow/contrib/hooks/gcp_mlengine_hook.py
deleted file mode 100644
index 57978e008b375..0000000000000
--- a/airflow/contrib/hooks/gcp_mlengine_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.mlengine`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.mlengine import MLEngineHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.mlengine`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_natural_language_hook.py b/airflow/contrib/hooks/gcp_natural_language_hook.py
deleted file mode 100644
index 86ee9f8675d57..0000000000000
--- a/airflow/contrib/hooks/gcp_natural_language_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.natural_language`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.natural_language import CloudNaturalLanguageHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.natural_language`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_pubsub_hook.py b/airflow/contrib/hooks/gcp_pubsub_hook.py
deleted file mode 100644
index 677a0f03fa506..0000000000000
--- a/airflow/contrib/hooks/gcp_pubsub_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.pubsub`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.pubsub import PubSubException, PubSubHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.pubsub`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_spanner_hook.py b/airflow/contrib/hooks/gcp_spanner_hook.py
deleted file mode 100644
index d3608ce5d0994..0000000000000
--- a/airflow/contrib/hooks/gcp_spanner_hook.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.spanner`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.spanner import SpannerHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.spanner`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudSpannerHook(SpannerHook):
- """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.spanner.SpannerHook`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_speech_to_text_hook.py b/airflow/contrib/hooks/gcp_speech_to_text_hook.py
deleted file mode 100644
index 1c4f4e16a79c0..0000000000000
--- a/airflow/contrib/hooks/gcp_speech_to_text_hook.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.speech_to_text`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.speech_to_text`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GCPSpeechToTextHook(CloudSpeechToTextHook):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.google.cloud.hooks.speech_to_text.CloudSpeechToTextHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_sql_hook.py b/airflow/contrib/hooks/gcp_sql_hook.py
deleted file mode 100644
index d3e85a7c3a57c..0000000000000
--- a/airflow/contrib/hooks/gcp_sql_hook.py
+++ /dev/null
@@ -1,47 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.cloud_sql`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.cloud_sql`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudSqlDatabaseHook(CloudSQLDatabaseHook):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlHook(CloudSQLHook):
- """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.sql.CloudSQLHook`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_tasks_hook.py b/airflow/contrib/hooks/gcp_tasks_hook.py
deleted file mode 100644
index 1753b2a1842d4..0000000000000
--- a/airflow/contrib/hooks/gcp_tasks_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.tasks`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.tasks`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_text_to_speech_hook.py b/airflow/contrib/hooks/gcp_text_to_speech_hook.py
deleted file mode 100644
index 9e53ec16f5f0e..0000000000000
--- a/airflow/contrib/hooks/gcp_text_to_speech_hook.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.text_to_speech`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.text_to_speech import CloudTextToSpeechHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.text_to_speech`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GCPTextToSpeechHook(CloudTextToSpeechHook):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.google.cloud.hooks.text_to_speech.CloudTextToSpeechHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_transfer_hook.py b/airflow/contrib/hooks/gcp_transfer_hook.py
deleted file mode 100644
index 57c098d52b7f9..0000000000000
--- a/airflow/contrib/hooks/gcp_transfer_hook.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated. Please use
-`airflow.providers.google.cloud.hooks.cloud_storage_transfer_service`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.cloud_storage_transfer_service import CloudDataTransferServiceHook
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GCPTransferServiceHook(CloudDataTransferServiceHook):
- """
- This class is deprecated. Please use
- `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service.CloudDataTransferServiceHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.hooks.cloud_storage_transfer_service
- .CloudDataTransferServiceHook`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gcp_translate_hook.py b/airflow/contrib/hooks/gcp_translate_hook.py
deleted file mode 100644
index 1b0cec8b5ec9c..0000000000000
--- a/airflow/contrib/hooks/gcp_translate_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.translate`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.translate import CloudTranslateHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.translate`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_video_intelligence_hook.py b/airflow/contrib/hooks/gcp_video_intelligence_hook.py
deleted file mode 100644
index a71ef4649e909..0000000000000
--- a/airflow/contrib/hooks/gcp_video_intelligence_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.video_intelligence`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.video_intelligence import CloudVideoIntelligenceHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.video_intelligence`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcp_vision_hook.py b/airflow/contrib/hooks/gcp_vision_hook.py
deleted file mode 100644
index 52f47f42bfdf5..0000000000000
--- a/airflow/contrib/hooks/gcp_vision_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.vision`."""
-
-import warnings
-
-from airflow.providers.google.cloud.hooks.vision import CloudVisionHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.vision`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/gcs_hook.py b/airflow/contrib/hooks/gcs_hook.py
deleted file mode 100644
index 910206acf9286..0000000000000
--- a/airflow/contrib/hooks/gcs_hook.py
+++ /dev/null
@@ -1,39 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.hooks.gcs`."""
-import warnings
-
-from airflow.providers.google.cloud.hooks.gcs import GCSHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageHook(GCSHook):
- """This class is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs.GCSHook`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.google.cloud.hooks.gcs.GCSHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/hooks/gdrive_hook.py b/airflow/contrib/hooks/gdrive_hook.py
deleted file mode 100644
index dad8459c58394..0000000000000
--- a/airflow/contrib/hooks/gdrive_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.suite.hooks.drive`."""
-
-import warnings
-
-from airflow.providers.google.suite.hooks.drive import GoogleDriveHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.suite.hooks.drive`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/grpc_hook.py b/airflow/contrib/hooks/grpc_hook.py
deleted file mode 100644
index f7aa6e2216fa9..0000000000000
--- a/airflow/contrib/hooks/grpc_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.grpc.hooks.grpc`."""
-
-import warnings
-
-from airflow.providers.grpc.hooks.grpc import GrpcHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.grpc.hooks.grpc`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/imap_hook.py b/airflow/contrib/hooks/imap_hook.py
deleted file mode 100644
index 57703966ab3c9..0000000000000
--- a/airflow/contrib/hooks/imap_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.imap.hooks.imap`."""
-
-import warnings
-
-from airflow.providers.imap.hooks.imap import ImapHook, Mail, MailPart # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.imap.hooks.imap`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/jenkins_hook.py b/airflow/contrib/hooks/jenkins_hook.py
deleted file mode 100644
index 178474ea77991..0000000000000
--- a/airflow/contrib/hooks/jenkins_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.jenkins.hooks.jenkins`."""
-
-import warnings
-
-from airflow.providers.jenkins.hooks.jenkins import JenkinsHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.jenkins.hooks.jenkins`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/jira_hook.py b/airflow/contrib/hooks/jira_hook.py
deleted file mode 100644
index 8f9d4670154a2..0000000000000
--- a/airflow/contrib/hooks/jira_hook.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.jira.hooks.jira`."""
-
-import warnings
-
-from airflow.providers.jira.hooks.jira import JiraHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.jira.hooks.jira`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/mongo_hook.py b/airflow/contrib/hooks/mongo_hook.py
deleted file mode 100644
index 63f6eea36ba51..0000000000000
--- a/airflow/contrib/hooks/mongo_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.mongo.hooks.mongo`."""
-
-import warnings
-
-from airflow.providers.mongo.hooks.mongo import MongoHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.mongo.hooks.mongo`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/openfaas_hook.py b/airflow/contrib/hooks/openfaas_hook.py
deleted file mode 100644
index a0e71ffe5e564..0000000000000
--- a/airflow/contrib/hooks/openfaas_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.openfaas.hooks.openfaas`."""
-
-import warnings
-
-from airflow.providers.openfaas.hooks.openfaas import OK_STATUS_CODE, OpenFaasHook, requests # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.openfaas.hooks.openfaas`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/opsgenie_alert_hook.py b/airflow/contrib/hooks/opsgenie_alert_hook.py
deleted file mode 100644
index abd9c89e09308..0000000000000
--- a/airflow/contrib/hooks/opsgenie_alert_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.opsgenie.hooks.opsgenie`."""
-
-import warnings
-
-from airflow.providers.opsgenie.hooks.opsgenie import OpsgenieAlertHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.opsgenie.hooks.opsgenie`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/pagerduty_hook.py b/airflow/contrib/hooks/pagerduty_hook.py
deleted file mode 100644
index 33797b0e60ea4..0000000000000
--- a/airflow/contrib/hooks/pagerduty_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.pagerduty.hooks.pagerduty`."""
-
-import warnings
-
-from airflow.providers.pagerduty.hooks.pagerduty import PagerdutyHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.pagerduty.hooks.pagerduty`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/pinot_hook.py b/airflow/contrib/hooks/pinot_hook.py
deleted file mode 100644
index 159677fdec52a..0000000000000
--- a/airflow/contrib/hooks/pinot_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.pinot.hooks.pinot`."""
-
-import warnings
-
-from airflow.providers.apache.pinot.hooks.pinot import PinotAdminHook, PinotDbApiHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.pinot.hooks.pinot`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/qubole_check_hook.py b/airflow/contrib/hooks/qubole_check_hook.py
deleted file mode 100644
index 0a674d7c763ad..0000000000000
--- a/airflow/contrib/hooks/qubole_check_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.qubole.hooks.qubole_check`."""
-
-import warnings
-
-from airflow.providers.qubole.hooks.qubole_check import QuboleCheckHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole_check`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/qubole_hook.py b/airflow/contrib/hooks/qubole_hook.py
deleted file mode 100644
index 6a695bca76462..0000000000000
--- a/airflow/contrib/hooks/qubole_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.qubole.hooks.qubole`."""
-
-import warnings
-
-from airflow.providers.qubole.hooks.qubole import QuboleHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.qubole.hooks.qubole`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/redis_hook.py b/airflow/contrib/hooks/redis_hook.py
deleted file mode 100644
index 57bdab5aead08..0000000000000
--- a/airflow/contrib/hooks/redis_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.redis.hooks.redis`."""
-
-import warnings
-
-from airflow.providers.redis.hooks.redis import RedisHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.redis.hooks.redis`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/redshift_hook.py b/airflow/contrib/hooks/redshift_hook.py
deleted file mode 100644
index f33515f6cd954..0000000000000
--- a/airflow/contrib/hooks/redshift_hook.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.redshift_cluster`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.redshift_cluster`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-__all__ = ["RedshiftHook"]
diff --git a/airflow/contrib/hooks/sagemaker_hook.py b/airflow/contrib/hooks/sagemaker_hook.py
deleted file mode 100644
index 321f25b022fa9..0000000000000
--- a/airflow/contrib/hooks/sagemaker_hook.py
+++ /dev/null
@@ -1,35 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.sagemaker import ( # noqa
- LogState,
- Position,
- SageMakerHook,
- argmin,
- secondary_training_status_changed,
- secondary_training_status_message,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/salesforce_hook.py b/airflow/contrib/hooks/salesforce_hook.py
deleted file mode 100644
index a707a527b8449..0000000000000
--- a/airflow/contrib/hooks/salesforce_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.salesforce.hooks.salesforce`."""
-
-import warnings
-
-from airflow.providers.salesforce.hooks.salesforce import SalesforceHook, pd # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.salesforce.hooks.salesforce`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/segment_hook.py b/airflow/contrib/hooks/segment_hook.py
deleted file mode 100644
index 6da62578cee8b..0000000000000
--- a/airflow/contrib/hooks/segment_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.segment.hooks.segment`."""
-
-import warnings
-
-from airflow.providers.segment.hooks.segment import SegmentHook, analytics # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.segment.hooks.segment`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/sftp_hook.py b/airflow/contrib/hooks/sftp_hook.py
deleted file mode 100644
index 0153e8e54d062..0000000000000
--- a/airflow/contrib/hooks/sftp_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.sftp.hooks.sftp`."""
-
-import warnings
-
-from airflow.providers.sftp.hooks.sftp import SFTPHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.sftp.hooks.sftp`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/slack_webhook_hook.py b/airflow/contrib/hooks/slack_webhook_hook.py
deleted file mode 100644
index f438d11575b43..0000000000000
--- a/airflow/contrib/hooks/slack_webhook_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.slack.hooks.slack_webhook`."""
-
-import warnings
-
-from airflow.providers.slack.hooks.slack_webhook import SlackWebhookHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.slack.hooks.slack_webhook`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/snowflake_hook.py b/airflow/contrib/hooks/snowflake_hook.py
deleted file mode 100644
index 804baccf66cdf..0000000000000
--- a/airflow/contrib/hooks/snowflake_hook.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.snowflake.hooks.snowflake`."""
-
-import warnings
-
-from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.snowflake.hooks.snowflake`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/spark_jdbc_hook.py b/airflow/contrib/hooks/spark_jdbc_hook.py
deleted file mode 100644
index 1b48d094acbb6..0000000000000
--- a/airflow/contrib/hooks/spark_jdbc_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.hooks.spark_jdbc`."""
-
-import warnings
-
-from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_jdbc`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/spark_sql_hook.py b/airflow/contrib/hooks/spark_sql_hook.py
deleted file mode 100644
index 6b262ed3efce5..0000000000000
--- a/airflow/contrib/hooks/spark_sql_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.hooks.spark_sql`."""
-
-import warnings
-
-from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_sql`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
deleted file mode 100644
index fbdbf4f9f0d53..0000000000000
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.hooks.spark_submit`."""
-
-import warnings
-
-from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.spark.hooks.spark_submit`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/sqoop_hook.py b/airflow/contrib/hooks/sqoop_hook.py
deleted file mode 100644
index f231c0f26a8dc..0000000000000
--- a/airflow/contrib/hooks/sqoop_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.sqoop.hooks.sqoop`."""
-
-import warnings
-
-from airflow.providers.apache.sqoop.hooks.sqoop import SqoopHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.sqoop.hooks.sqoop`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/ssh_hook.py b/airflow/contrib/hooks/ssh_hook.py
deleted file mode 100644
index ef3000d888c50..0000000000000
--- a/airflow/contrib/hooks/ssh_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.ssh.hooks.ssh`."""
-
-import warnings
-
-from airflow.providers.ssh.hooks.ssh import SSHHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.ssh.hooks.ssh`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/vertica_hook.py b/airflow/contrib/hooks/vertica_hook.py
deleted file mode 100644
index fc84b222d0a5a..0000000000000
--- a/airflow/contrib/hooks/vertica_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.vertica.hooks.vertica`."""
-
-import warnings
-
-from airflow.providers.vertica.hooks.vertica import VerticaHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.vertica.hooks.vertica`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/wasb_hook.py b/airflow/contrib/hooks/wasb_hook.py
deleted file mode 100644
index 3b5eb650934af..0000000000000
--- a/airflow/contrib/hooks/wasb_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.hooks.wasb`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.hooks.wasb import WasbHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.hooks.wasb`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/hooks/winrm_hook.py b/airflow/contrib/hooks/winrm_hook.py
deleted file mode 100644
index 35e7db2bc7294..0000000000000
--- a/airflow/contrib/hooks/winrm_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.winrm.hooks.winrm`."""
-
-import warnings
-
-from airflow.providers.microsoft.winrm.hooks.winrm import WinRMHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.winrm.hooks.winrm`.",
- DeprecationWarning,
- stacklevel=2,
-)
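
Review note: the shim modules deleted above all repeat one pattern (import the new class, then emit a module-level warning), and the `__deprecated_classes` mapping added to `airflow/contrib/operators/__init__.py` below replaces those files with data passed to `add_deprecated_classes`. As a rough illustration of how such a mapping can be resolved lazily, here is a minimal sketch built on a module-level `__getattr__` (PEP 562). It is not the Airflow implementation: `install_redirects`, `DEPRECATED`, `demo_pkg`, and `OldOrderedDict` are hypothetical names, and the redirect targets the standard library so the sketch runs without Airflow installed.

```python
import collections
import importlib
import sys
import types
import warnings

# Hypothetical mapping in the same shape as ``__deprecated_classes``:
# {submodule name: {old class name: dotted path of the replacement}}.
DEPRECATED = {
    "old_collections": {
        "OldOrderedDict": "collections.OrderedDict",
    },
}


def install_redirects(package_name, mapping):
    """Register one stub module per entry; old names are resolved on first access."""
    for submodule, classes in mapping.items():
        full_name = f"{package_name}.{submodule}"
        stub = types.ModuleType(full_name)

        # Default arguments bind the per-iteration values (avoids late binding).
        def _getattr(name, _classes=classes, _full_name=full_name):
            if name not in _classes:
                raise AttributeError(f"module {_full_name!r} has no attribute {name!r}")
            target = _classes[name]
            warnings.warn(
                f"`{_full_name}.{name}` is deprecated. Please use `{target}`.",
                DeprecationWarning,
                stacklevel=2,
            )
            module_path, _, attr = target.rpartition(".")
            return getattr(importlib.import_module(module_path), attr)

        # PEP 562: a ``__getattr__`` in the module namespace handles missing attributes.
        stub.__getattr__ = _getattr
        sys.modules[full_name] = stub


install_redirects("demo_pkg", DEPRECATED)
legacy = importlib.import_module("demo_pkg.old_collections")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    cls = legacy.OldOrderedDict  # warns, then returns the replacement class

print(cls is collections.OrderedDict, len(caught))
```

With this shape, the warning fires only when a deprecated name is actually accessed, and the replacement module is imported only at that point, so one `__init__.py` can stand in for many near-identical shim files.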
diff --git a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py
index 2041adb9ea24a..8e02c871b84d2 100644
--- a/airflow/contrib/operators/__init__.py
+++ b/airflow/contrib/operators/__init__.py
@@ -15,5 +15,1128 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""This package is deprecated. Please use `airflow.operators` or `airflow.providers.*.operators`."""
+from __future__ import annotations
+
+import warnings
+
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+warnings.warn(
+ "This package is deprecated. Please use `airflow.operators` or `airflow.providers.*.operators`.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+)
+
+__deprecated_classes = {
+ 'adls_list_operator': {
+ 'ADLSListOperator': 'airflow.providers.microsoft.azure.operators.adls.ADLSListOperator',
+ 'AzureDataLakeStorageListOperator': (
+ 'airflow.providers.microsoft.azure.operators.adls.ADLSListOperator'
+ ),
+ },
+ 'adls_to_gcs': {
+ 'ADLSToGCSOperator': 'airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator',
+ 'AdlsToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator'
+ ),
+ },
+ 'aws_athena_operator': {
+ 'AWSAthenaOperator': 'airflow.providers.amazon.aws.operators.athena.AthenaOperator',
+ },
+ 'aws_sqs_publish_operator': {
+ 'SqsPublishOperator': 'airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator',
+ 'SQSPublishOperator': 'airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator',
+ },
+ 'awsbatch_operator': {
+ 'BatchProtocol': 'airflow.providers.amazon.aws.hooks.batch_client.BatchProtocol',
+ 'BatchOperator': 'airflow.providers.amazon.aws.operators.batch.BatchOperator',
+ 'AWSBatchOperator': 'airflow.providers.amazon.aws.operators.batch.BatchOperator',
+ },
+ 'azure_container_instances_operator': {
+ 'AzureContainerInstancesOperator': (
+ 'airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstancesOperator'
+ ),
+ },
+ 'azure_cosmos_operator': {
+ 'AzureCosmosInsertDocumentOperator': (
+ 'airflow.providers.microsoft.azure.operators.cosmos.AzureCosmosInsertDocumentOperator'
+ ),
+ },
+ 'bigquery_check_operator': {
+ 'BigQueryCheckOperator': 'airflow.providers.google.cloud.operators.bigquery.BigQueryCheckOperator',
+ 'BigQueryIntervalCheckOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryIntervalCheckOperator'
+ ),
+ 'BigQueryValueCheckOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryValueCheckOperator'
+ ),
+ },
+ 'bigquery_get_data': {
+ 'BigQueryGetDataOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryGetDataOperator'
+ ),
+ },
+ 'bigquery_operator': {
+ 'BigQueryCreateEmptyDatasetOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryCreateEmptyDatasetOperator'
+ ),
+ 'BigQueryCreateEmptyTableOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryCreateEmptyTableOperator'
+ ),
+ 'BigQueryCreateExternalTableOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryCreateExternalTableOperator'
+ ),
+ 'BigQueryDeleteDatasetOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteDatasetOperator'
+ ),
+ 'BigQueryExecuteQueryOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator'
+ ),
+ 'BigQueryGetDatasetOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryGetDatasetOperator'
+ ),
+ 'BigQueryGetDatasetTablesOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryGetDatasetTablesOperator'
+ ),
+ 'BigQueryPatchDatasetOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryPatchDatasetOperator'
+ ),
+ 'BigQueryUpdateDatasetOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryUpdateDatasetOperator'
+ ),
+ 'BigQueryUpsertTableOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryUpsertTableOperator'
+ ),
+ 'BigQueryOperator': 'airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator',
+ },
+ 'bigquery_table_delete_operator': {
+ 'BigQueryDeleteTableOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator'
+ ),
+ 'BigQueryTableDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator'
+ ),
+ },
+ 'bigquery_to_bigquery': {
+ 'BigQueryToBigQueryOperator': (
+ 'airflow.providers.google.cloud.transfers.bigquery_to_bigquery.BigQueryToBigQueryOperator'
+ ),
+ },
+ 'bigquery_to_gcs': {
+ 'BigQueryToGCSOperator': (
+ 'airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator'
+ ),
+ 'BigQueryToCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator'
+ ),
+ },
+ 'bigquery_to_mysql_operator': {
+ 'BigQueryToMySqlOperator': (
+ 'airflow.providers.google.cloud.transfers.bigquery_to_mysql.BigQueryToMySqlOperator'
+ ),
+ },
+ 'cassandra_to_gcs': {
+ 'CassandraToGCSOperator': (
+ 'airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator'
+ ),
+ 'CassandraToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator'
+ ),
+ },
+ 'databricks_operator': {
+ 'DatabricksRunNowOperator': (
+ 'airflow.providers.databricks.operators.databricks.DatabricksRunNowOperator'
+ ),
+ 'DatabricksSubmitRunOperator': (
+ 'airflow.providers.databricks.operators.databricks.DatabricksSubmitRunOperator'
+ ),
+ },
+ 'dataflow_operator': {
+ 'DataflowCreateJavaJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator'
+ ),
+ 'DataflowCreatePythonJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator'
+ ),
+ 'DataflowTemplatedJobStartOperator': (
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator'
+ ),
+ 'DataFlowJavaOperator': (
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator'
+ ),
+ 'DataFlowPythonOperator': (
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator'
+ ),
+ 'DataflowTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator'
+ ),
+ },
+ 'dataproc_operator': {
+ 'DataprocCreateClusterOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator'
+ ),
+ 'DataprocDeleteClusterOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator'
+ ),
+ 'DataprocInstantiateInlineWorkflowTemplateOperator':
+ 'airflow.providers.google.cloud.operators.dataproc.'
+ 'DataprocInstantiateInlineWorkflowTemplateOperator',
+ 'DataprocInstantiateWorkflowTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator'
+ ),
+ 'DataprocJobBaseOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator'
+ ),
+ 'DataprocScaleClusterOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator'
+ ),
+ 'DataprocSubmitHadoopJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator'
+ ),
+ 'DataprocSubmitHiveJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator'
+ ),
+ 'DataprocSubmitPigJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator'
+ ),
+ 'DataprocSubmitPySparkJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator'
+ ),
+ 'DataprocSubmitSparkJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator'
+ ),
+ 'DataprocSubmitSparkSqlJobOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator'
+ ),
+ 'DataprocClusterCreateOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator'
+ ),
+ 'DataprocClusterDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator'
+ ),
+ 'DataprocClusterScaleOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator'
+ ),
+ 'DataProcHadoopOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator'
+ ),
+ 'DataProcHiveOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator'
+ ),
+ 'DataProcJobBaseOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator'
+ ),
+ 'DataProcPigOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator'
+ ),
+ 'DataProcPySparkOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator'
+ ),
+ 'DataProcSparkOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator'
+ ),
+ 'DataProcSparkSqlOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator'
+ ),
+ 'DataprocWorkflowTemplateInstantiateInlineOperator':
+ 'airflow.providers.google.cloud.operators.dataproc.'
+ 'DataprocInstantiateInlineWorkflowTemplateOperator',
+ 'DataprocWorkflowTemplateInstantiateOperator': (
+ 'airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator'
+ ),
+ },
+ 'datastore_export_operator': {
+ 'CloudDatastoreExportEntitiesOperator': (
+ 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator'
+ ),
+ 'DatastoreExportOperator': (
+ 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator'
+ ),
+ },
+ 'datastore_import_operator': {
+ 'CloudDatastoreImportEntitiesOperator': (
+ 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator'
+ ),
+ 'DatastoreImportOperator': (
+ 'airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator'
+ ),
+ },
+ 'dingding_operator': {
+ 'DingdingOperator': 'airflow.providers.dingding.operators.dingding.DingdingOperator',
+ },
+ 'discord_webhook_operator': {
+ 'DiscordWebhookOperator': (
+ 'airflow.providers.discord.operators.discord_webhook.DiscordWebhookOperator'
+ ),
+ },
+ 'docker_swarm_operator': {
+ 'DockerSwarmOperator': 'airflow.providers.docker.operators.docker_swarm.DockerSwarmOperator',
+ },
+ 'druid_operator': {
+ 'DruidOperator': 'airflow.providers.apache.druid.operators.druid.DruidOperator',
+ },
+ 'dynamodb_to_s3': {
+ 'DynamoDBToS3Operator': 'airflow.providers.amazon.aws.transfers.dynamodb_to_s3.DynamoDBToS3Operator',
+ },
+ 'ecs_operator': {
+ 'EcsProtocol': 'airflow.providers.amazon.aws.hooks.ecs.EcsProtocol',
+ 'EcsRunTaskOperator': 'airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator',
+ 'EcsOperator': 'airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator',
+ },
+ 'file_to_gcs': {
+ 'LocalFilesystemToGCSOperator': (
+ 'airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator'
+ ),
+ 'FileToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator'
+ ),
+ },
+ 'file_to_wasb': {
+ 'LocalFilesystemToWasbOperator': (
+ 'airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator'
+ ),
+ 'FileToWasbOperator': (
+ 'airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator'
+ ),
+ },
+ 'gcp_bigtable_operator': {
+ 'BigtableCreateInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator'
+ ),
+ 'BigtableCreateTableOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator'
+ ),
+ 'BigtableDeleteInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator'
+ ),
+ 'BigtableDeleteTableOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator'
+ ),
+ 'BigtableUpdateClusterOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator'
+ ),
+ 'BigtableTableReplicationCompletedSensor': (
+ 'airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor'
+ ),
+ 'BigtableClusterUpdateOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator'
+ ),
+ 'BigtableInstanceCreateOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator'
+ ),
+ 'BigtableInstanceDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator'
+ ),
+ 'BigtableTableCreateOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator'
+ ),
+ 'BigtableTableDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator'
+ ),
+ 'BigtableTableWaitForReplicationSensor': (
+ 'airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor'
+ ),
+ },
+ 'gcp_cloud_build_operator': {
+ 'CloudBuildCreateBuildOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_build.CloudBuildCreateBuildOperator'
+ ),
+ },
+ 'gcp_compute_operator': {
+ 'ComputeEngineBaseOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator'
+ ),
+ 'ComputeEngineCopyInstanceTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineCopyInstanceTemplateOperator'
+ ),
+ 'ComputeEngineInstanceGroupUpdateManagerTemplateOperator':
+ 'airflow.providers.google.cloud.operators.compute.'
+ 'ComputeEngineInstanceGroupUpdateManagerTemplateOperator',
+ 'ComputeEngineSetMachineTypeOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineSetMachineTypeOperator'
+ ),
+ 'ComputeEngineStartInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineStartInstanceOperator'
+ ),
+ 'ComputeEngineStopInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineStopInstanceOperator'
+ ),
+ 'GceBaseOperator': 'airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator',
+ 'GceInstanceGroupManagerUpdateTemplateOperator':
+ 'airflow.providers.google.cloud.operators.compute.'
+ 'ComputeEngineInstanceGroupUpdateManagerTemplateOperator',
+ 'GceInstanceStartOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineStartInstanceOperator'
+ ),
+ 'GceInstanceStopOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineStopInstanceOperator'
+ ),
+ 'GceInstanceTemplateCopyOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineCopyInstanceTemplateOperator'
+ ),
+ 'GceSetMachineTypeOperator': (
+ 'airflow.providers.google.cloud.operators.compute.ComputeEngineSetMachineTypeOperator'
+ ),
+ },
+ 'gcp_container_operator': {
+ 'GKECreateClusterOperator': (
+ 'airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator'
+ ),
+ 'GKEDeleteClusterOperator': (
+ 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEDeleteClusterOperator'
+ ),
+ 'GKEStartPodOperator': (
+ 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator'
+ ),
+ 'GKEClusterCreateOperator': (
+ 'airflow.providers.google.cloud.operators.kubernetes_engine.GKECreateClusterOperator'
+ ),
+ 'GKEClusterDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEDeleteClusterOperator'
+ ),
+ 'GKEPodOperator': 'airflow.providers.google.cloud.operators.kubernetes_engine.GKEStartPodOperator',
+ },
+ 'gcp_dlp_operator': {
+ 'CloudDLPCancelDLPJobOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPCancelDLPJobOperator'
+ ),
+ 'CloudDLPCreateDeidentifyTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateDeidentifyTemplateOperator'
+ ),
+ 'CloudDLPCreateDLPJobOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateDLPJobOperator'
+ ),
+ 'CloudDLPCreateInspectTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateInspectTemplateOperator'
+ ),
+ 'CloudDLPCreateJobTriggerOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateJobTriggerOperator'
+ ),
+ 'CloudDLPCreateStoredInfoTypeOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPCreateStoredInfoTypeOperator'
+ ),
+ 'CloudDLPDeidentifyContentOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeidentifyContentOperator'
+ ),
+ 'CloudDLPDeleteDeidentifyTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDeidentifyTemplateOperator'
+ ),
+ 'CloudDLPDeleteDLPJobOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator'
+ ),
+ 'CloudDLPDeleteInspectTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteInspectTemplateOperator'
+ ),
+ 'CloudDLPDeleteJobTriggerOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteJobTriggerOperator'
+ ),
+ 'CloudDLPDeleteStoredInfoTypeOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteStoredInfoTypeOperator'
+ ),
+ 'CloudDLPGetDeidentifyTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDeidentifyTemplateOperator'
+ ),
+ 'CloudDLPGetDLPJobOperator': 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator',
+ 'CloudDLPGetDLPJobTriggerOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator'
+ ),
+ 'CloudDLPGetInspectTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetInspectTemplateOperator'
+ ),
+ 'CloudDLPGetStoredInfoTypeOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetStoredInfoTypeOperator'
+ ),
+ 'CloudDLPInspectContentOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPInspectContentOperator'
+ ),
+ 'CloudDLPListDeidentifyTemplatesOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListDeidentifyTemplatesOperator'
+ ),
+ 'CloudDLPListDLPJobsOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator'
+ ),
+ 'CloudDLPListInfoTypesOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListInfoTypesOperator'
+ ),
+ 'CloudDLPListInspectTemplatesOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListInspectTemplatesOperator'
+ ),
+ 'CloudDLPListJobTriggersOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListJobTriggersOperator'
+ ),
+ 'CloudDLPListStoredInfoTypesOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListStoredInfoTypesOperator'
+ ),
+ 'CloudDLPRedactImageOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator'
+ ),
+ 'CloudDLPReidentifyContentOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPReidentifyContentOperator'
+ ),
+ 'CloudDLPUpdateDeidentifyTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateDeidentifyTemplateOperator'
+ ),
+ 'CloudDLPUpdateInspectTemplateOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateInspectTemplateOperator'
+ ),
+ 'CloudDLPUpdateJobTriggerOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateJobTriggerOperator'
+ ),
+ 'CloudDLPUpdateStoredInfoTypeOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPUpdateStoredInfoTypeOperator'
+ ),
+ 'CloudDLPDeleteDlpJobOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator'
+ ),
+ 'CloudDLPGetDlpJobOperator': 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator',
+ 'CloudDLPGetJobTripperOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator'
+ ),
+ 'CloudDLPListDlpJobsOperator': (
+ 'airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator'
+ ),
+ },
+ 'gcp_function_operator': {
+ 'CloudFunctionDeleteFunctionOperator': (
+ 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeleteFunctionOperator'
+ ),
+ 'CloudFunctionDeployFunctionOperator': (
+ 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeployFunctionOperator'
+ ),
+ 'GcfFunctionDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeleteFunctionOperator'
+ ),
+ 'GcfFunctionDeployOperator': (
+ 'airflow.providers.google.cloud.operators.functions.CloudFunctionDeployFunctionOperator'
+ ),
+ },
+ 'gcp_natural_language_operator': {
+ 'CloudNaturalLanguageAnalyzeEntitiesOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageAnalyzeEntitiesOperator',
+ 'CloudNaturalLanguageAnalyzeEntitySentimentOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageAnalyzeEntitySentimentOperator',
+ 'CloudNaturalLanguageAnalyzeSentimentOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageAnalyzeSentimentOperator',
+ 'CloudNaturalLanguageClassifyTextOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageClassifyTextOperator',
+ 'CloudLanguageAnalyzeEntitiesOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageAnalyzeEntitiesOperator',
+ 'CloudLanguageAnalyzeEntitySentimentOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageAnalyzeEntitySentimentOperator',
+ 'CloudLanguageAnalyzeSentimentOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageAnalyzeSentimentOperator',
+ 'CloudLanguageClassifyTextOperator':
+ 'airflow.providers.google.cloud.operators.natural_language.'
+ 'CloudNaturalLanguageClassifyTextOperator',
+ },
+ 'gcp_spanner_operator': {
+ 'SpannerDeleteDatabaseInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteDatabaseInstanceOperator'
+ ),
+ 'SpannerDeleteInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteInstanceOperator'
+ ),
+ 'SpannerDeployDatabaseInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeployDatabaseInstanceOperator'
+ ),
+ 'SpannerDeployInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeployInstanceOperator'
+ ),
+ 'SpannerQueryDatabaseInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerQueryDatabaseInstanceOperator'
+ ),
+ 'SpannerUpdateDatabaseInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerUpdateDatabaseInstanceOperator'
+ ),
+ 'CloudSpannerInstanceDatabaseDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteDatabaseInstanceOperator'
+ ),
+ 'CloudSpannerInstanceDatabaseDeployOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeployDatabaseInstanceOperator'
+ ),
+ 'CloudSpannerInstanceDatabaseQueryOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerQueryDatabaseInstanceOperator'
+ ),
+ 'CloudSpannerInstanceDatabaseUpdateOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerUpdateDatabaseInstanceOperator'
+ ),
+ 'CloudSpannerInstanceDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeleteInstanceOperator'
+ ),
+ 'CloudSpannerInstanceDeployOperator': (
+ 'airflow.providers.google.cloud.operators.spanner.SpannerDeployInstanceOperator'
+ ),
+ },
+ 'gcp_speech_to_text_operator': {
+ 'CloudSpeechToTextRecognizeSpeechOperator': (
+ 'airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextRecognizeSpeechOperator'
+ ),
+ 'GcpSpeechToTextRecognizeSpeechOperator': (
+ 'airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextRecognizeSpeechOperator'
+ ),
+ },
+ 'gcp_sql_operator': {
+ 'CloudSQLBaseOperator': 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator',
+ 'CloudSQLCreateInstanceDatabaseOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceDatabaseOperator'
+ ),
+ 'CloudSQLCreateInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceOperator'
+ ),
+ 'CloudSQLDeleteInstanceDatabaseOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceDatabaseOperator'
+ ),
+ 'CloudSQLDeleteInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceOperator'
+ ),
+ 'CloudSQLExecuteQueryOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExecuteQueryOperator'
+ ),
+ 'CloudSQLExportInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExportInstanceOperator'
+ ),
+ 'CloudSQLImportInstanceOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLImportInstanceOperator'
+ ),
+ 'CloudSQLInstancePatchOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLInstancePatchOperator'
+ ),
+ 'CloudSQLPatchInstanceDatabaseOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLPatchInstanceDatabaseOperator'
+ ),
+ 'CloudSqlBaseOperator': 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLBaseOperator',
+ 'CloudSqlInstanceCreateOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceOperator'
+ ),
+ 'CloudSqlInstanceDatabaseCreateOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLCreateInstanceDatabaseOperator'
+ ),
+ 'CloudSqlInstanceDatabaseDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceDatabaseOperator'
+ ),
+ 'CloudSqlInstanceDatabasePatchOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLPatchInstanceDatabaseOperator'
+ ),
+ 'CloudSqlInstanceDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLDeleteInstanceOperator'
+ ),
+ 'CloudSqlInstanceExportOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExportInstanceOperator'
+ ),
+ 'CloudSqlInstanceImportOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLImportInstanceOperator'
+ ),
+ 'CloudSqlInstancePatchOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLInstancePatchOperator'
+ ),
+ 'CloudSqlQueryOperator': (
+ 'airflow.providers.google.cloud.operators.cloud_sql.CloudSQLExecuteQueryOperator'
+ ),
+ },
+ 'gcp_tasks_operator': {
+ 'CloudTasksQueueCreateOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueCreateOperator'
+ ),
+ 'CloudTasksQueueDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueDeleteOperator'
+ ),
+ 'CloudTasksQueueGetOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueGetOperator'
+ ),
+ 'CloudTasksQueuePauseOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueuePauseOperator'
+ ),
+ 'CloudTasksQueuePurgeOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueuePurgeOperator'
+ ),
+ 'CloudTasksQueueResumeOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueResumeOperator'
+ ),
+ 'CloudTasksQueuesListOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueuesListOperator'
+ ),
+ 'CloudTasksQueueUpdateOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksQueueUpdateOperator'
+ ),
+ 'CloudTasksTaskCreateOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskCreateOperator'
+ ),
+ 'CloudTasksTaskDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskDeleteOperator'
+ ),
+ 'CloudTasksTaskGetOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskGetOperator'
+ ),
+ 'CloudTasksTaskRunOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksTaskRunOperator'
+ ),
+ 'CloudTasksTasksListOperator': (
+ 'airflow.providers.google.cloud.operators.tasks.CloudTasksTasksListOperator'
+ ),
+ },
+ 'gcp_text_to_speech_operator': {
+ 'CloudTextToSpeechSynthesizeOperator': (
+ 'airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator'
+ ),
+ 'GcpTextToSpeechSynthesizeOperator': (
+ 'airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator'
+ ),
+ },
+ 'gcp_transfer_operator': {
+ 'CloudDataTransferServiceCancelOperationOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceCancelOperationOperator',
+ 'CloudDataTransferServiceCreateJobOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceCreateJobOperator',
+ 'CloudDataTransferServiceDeleteJobOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceDeleteJobOperator',
+ 'CloudDataTransferServiceGCSToGCSOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceGCSToGCSOperator',
+ 'CloudDataTransferServiceGetOperationOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceGetOperationOperator',
+ 'CloudDataTransferServiceListOperationsOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceListOperationsOperator',
+ 'CloudDataTransferServicePauseOperationOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServicePauseOperationOperator',
+ 'CloudDataTransferServiceResumeOperationOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceResumeOperationOperator',
+ 'CloudDataTransferServiceS3ToGCSOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceS3ToGCSOperator',
+ 'CloudDataTransferServiceUpdateJobOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceUpdateJobOperator',
+ 'GcpTransferServiceJobCreateOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceCreateJobOperator',
+ 'GcpTransferServiceJobDeleteOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceDeleteJobOperator',
+ 'GcpTransferServiceJobUpdateOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceUpdateJobOperator',
+ 'GcpTransferServiceOperationCancelOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceCancelOperationOperator',
+ 'GcpTransferServiceOperationGetOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceGetOperationOperator',
+ 'GcpTransferServiceOperationPauseOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServicePauseOperationOperator',
+ 'GcpTransferServiceOperationResumeOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceResumeOperationOperator',
+ 'GcpTransferServiceOperationsListOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceListOperationsOperator',
+ 'GoogleCloudStorageToGoogleCloudStorageTransferOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceGCSToGCSOperator',
+ 'S3ToGoogleCloudStorageTransferOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceS3ToGCSOperator',
+ },
+ 'gcp_translate_operator': {
+ 'CloudTranslateTextOperator': (
+ 'airflow.providers.google.cloud.operators.translate.CloudTranslateTextOperator'
+ ),
+ },
+ 'gcp_translate_speech_operator': {
+ 'CloudTranslateSpeechOperator': (
+ 'airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator'
+ ),
+ 'GcpTranslateSpeechOperator': (
+ 'airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator'
+ ),
+ },
+ 'gcp_video_intelligence_operator': {
+ 'CloudVideoIntelligenceDetectVideoExplicitContentOperator':
+ 'airflow.providers.google.cloud.operators.video_intelligence.'
+ 'CloudVideoIntelligenceDetectVideoExplicitContentOperator',
+ 'CloudVideoIntelligenceDetectVideoLabelsOperator':
+ 'airflow.providers.google.cloud.operators.video_intelligence.'
+ 'CloudVideoIntelligenceDetectVideoLabelsOperator',
+ 'CloudVideoIntelligenceDetectVideoShotsOperator':
+ 'airflow.providers.google.cloud.operators.video_intelligence.'
+ 'CloudVideoIntelligenceDetectVideoShotsOperator',
+ },
+ 'gcp_vision_operator': {
+ 'CloudVisionAddProductToProductSetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionAddProductToProductSetOperator'
+ ),
+ 'CloudVisionCreateProductOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator'
+ ),
+ 'CloudVisionCreateProductSetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator'
+ ),
+ 'CloudVisionCreateReferenceImageOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator'
+ ),
+ 'CloudVisionDeleteProductOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator'
+ ),
+ 'CloudVisionDeleteProductSetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator'
+ ),
+ 'CloudVisionDetectImageLabelsOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageLabelsOperator'
+ ),
+ 'CloudVisionDetectImageSafeSearchOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDetectImageSafeSearchOperator'
+ ),
+ 'CloudVisionDetectTextOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDetectTextOperator'
+ ),
+ 'CloudVisionGetProductOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator'
+ ),
+ 'CloudVisionGetProductSetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator'
+ ),
+ 'CloudVisionImageAnnotateOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator'
+ ),
+ 'CloudVisionRemoveProductFromProductSetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionRemoveProductFromProductSetOperator'
+ ),
+ 'CloudVisionTextDetectOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator'
+ ),
+ 'CloudVisionUpdateProductOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator'
+ ),
+ 'CloudVisionUpdateProductSetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator'
+ ),
+ 'CloudVisionAnnotateImageOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator'
+ ),
+ 'CloudVisionDetectDocumentTextOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator'
+ ),
+ 'CloudVisionProductCreateOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator'
+ ),
+ 'CloudVisionProductDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator'
+ ),
+ 'CloudVisionProductGetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator'
+ ),
+ 'CloudVisionProductSetCreateOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator'
+ ),
+ 'CloudVisionProductSetDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator'
+ ),
+ 'CloudVisionProductSetGetOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator'
+ ),
+ 'CloudVisionProductSetUpdateOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator'
+ ),
+ 'CloudVisionProductUpdateOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator'
+ ),
+ 'CloudVisionReferenceImageCreateOperator': (
+ 'airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator'
+ ),
+ },
+ 'gcs_acl_operator': {
+ 'GCSBucketCreateAclEntryOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator'
+ ),
+ 'GCSObjectCreateAclEntryOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator'
+ ),
+ 'GoogleCloudStorageBucketCreateAclEntryOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator'
+ ),
+ 'GoogleCloudStorageObjectCreateAclEntryOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator'
+ ),
+ },
+ 'gcs_delete_operator': {
+ 'GCSDeleteObjectsOperator': 'airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator',
+ 'GoogleCloudStorageDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator'
+ ),
+ },
+ 'gcs_download_operator': {
+ 'GCSToLocalFilesystemOperator': (
+ 'airflow.providers.google.cloud.transfers.gcs_to_local.GCSToLocalFilesystemOperator'
+ ),
+ 'GoogleCloudStorageDownloadOperator': (
+ 'airflow.providers.google.cloud.transfers.gcs_to_local.GCSToLocalFilesystemOperator'
+ ),
+ },
+ 'gcs_list_operator': {
+ 'GCSListObjectsOperator': 'airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator',
+ 'GoogleCloudStorageListOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator'
+ ),
+ },
+ 'gcs_operator': {
+ 'GCSCreateBucketOperator': 'airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator',
+ 'GoogleCloudStorageCreateBucketOperator': (
+ 'airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator'
+ ),
+ },
+ 'gcs_to_bq': {
+ 'GCSToBigQueryOperator': (
+ 'airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator'
+ ),
+ 'GoogleCloudStorageToBigQueryOperator': (
+ 'airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator'
+ ),
+ },
+ 'gcs_to_gcs': {
+ 'GCSToGCSOperator': 'airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator',
+ 'GoogleCloudStorageToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator'
+ ),
+ },
+ 'gcs_to_gdrive_operator': {
+ 'GCSToGoogleDriveOperator': (
+ 'airflow.providers.google.suite.transfers.gcs_to_gdrive.GCSToGoogleDriveOperator'
+ ),
+ },
+ 'gcs_to_s3': {
+ 'GCSToS3Operator': 'airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator',
+ 'GoogleCloudStorageToS3Operator': 'airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator',
+ },
+ 'grpc_operator': {
+ 'GrpcOperator': 'airflow.providers.grpc.operators.grpc.GrpcOperator',
+ },
+ 'hive_to_dynamodb': {
+ 'HiveToDynamoDBOperator': (
+ 'airflow.providers.amazon.aws.transfers.hive_to_dynamodb.HiveToDynamoDBOperator'
+ ),
+ },
+ 'imap_attachment_to_s3_operator': {
+ 'ImapAttachmentToS3Operator': (
+ 'airflow.providers.amazon.aws.transfers.imap_attachment_to_s3.ImapAttachmentToS3Operator'
+ ),
+ },
+ 'jenkins_job_trigger_operator': {
+ 'JenkinsJobTriggerOperator': (
+ 'airflow.providers.jenkins.operators.jenkins_job_trigger.JenkinsJobTriggerOperator'
+ ),
+ },
+ 'jira_operator': {
+ 'JiraOperator': 'airflow.providers.atlassian.jira.operators.jira.JiraOperator',
+ },
+ 'kubernetes_pod_operator': {
+ 'KubernetesPodOperator': (
+ 'airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator'
+ ),
+ },
+ 'mlengine_operator': {
+ 'MLEngineManageModelOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator'
+ ),
+ 'MLEngineManageVersionOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator'
+ ),
+ 'MLEngineStartBatchPredictionJobOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator'
+ ),
+ 'MLEngineStartTrainingJobOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator'
+ ),
+ 'MLEngineBatchPredictionOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator'
+ ),
+ 'MLEngineModelOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator'
+ ),
+ 'MLEngineTrainingOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator'
+ ),
+ 'MLEngineVersionOperator': (
+ 'airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator'
+ ),
+ },
+ 'mongo_to_s3': {
+ 'MongoToS3Operator': 'airflow.providers.amazon.aws.transfers.mongo_to_s3.MongoToS3Operator',
+ },
+ 'mssql_to_gcs': {
+ 'MSSQLToGCSOperator': 'airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator',
+ 'MsSqlToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator'
+ ),
+ },
+ 'mysql_to_gcs': {
+ 'MySQLToGCSOperator': 'airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator',
+ 'MySqlToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator'
+ ),
+ },
+ 'opsgenie_alert_operator': {
+ 'OpsgenieCreateAlertOperator': (
+ 'airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator'
+ ),
+ 'OpsgenieAlertOperator': 'airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator',
+ },
+ 'oracle_to_azure_data_lake_transfer': {
+ 'OracleToAzureDataLakeOperator':
+ 'airflow.providers.microsoft.azure.transfers.'
+ 'oracle_to_azure_data_lake.OracleToAzureDataLakeOperator',
+ },
+ 'oracle_to_oracle_transfer': {
+ 'OracleToOracleOperator': (
+ 'airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator'
+ ),
+ 'OracleToOracleTransfer': (
+ 'airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator'
+ ),
+ },
+ 'postgres_to_gcs_operator': {
+ 'PostgresToGCSOperator': (
+ 'airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator'
+ ),
+ 'PostgresToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator'
+ ),
+ },
+ 'pubsub_operator': {
+ 'PubSubCreateSubscriptionOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator'
+ ),
+ 'PubSubCreateTopicOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator'
+ ),
+ 'PubSubDeleteSubscriptionOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator'
+ ),
+ 'PubSubDeleteTopicOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator'
+ ),
+ 'PubSubPublishMessageOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator'
+ ),
+ 'PubSubPublishOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator'
+ ),
+ 'PubSubSubscriptionCreateOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator'
+ ),
+ 'PubSubSubscriptionDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator'
+ ),
+ 'PubSubTopicCreateOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator'
+ ),
+ 'PubSubTopicDeleteOperator': (
+ 'airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator'
+ ),
+ },
+ 'qubole_check_operator': {
+ 'QuboleCheckOperator': 'airflow.providers.qubole.operators.qubole_check.QuboleCheckOperator',
+ 'QuboleValueCheckOperator': (
+ 'airflow.providers.qubole.operators.qubole_check.QuboleValueCheckOperator'
+ ),
+ },
+ 'qubole_operator': {
+ 'QuboleOperator': 'airflow.providers.qubole.operators.qubole.QuboleOperator',
+ },
+ 'redis_publish_operator': {
+ 'RedisPublishOperator': 'airflow.providers.redis.operators.redis_publish.RedisPublishOperator',
+ },
+ 's3_to_gcs_operator': {
+ 'S3ToGCSOperator': 'airflow.providers.google.cloud.transfers.s3_to_gcs.S3ToGCSOperator',
+ },
+ 's3_to_gcs_transfer_operator': {
+ 'CloudDataTransferServiceS3ToGCSOperator':
+ 'airflow.providers.google.cloud.operators.cloud_storage_transfer_service.'
+ 'CloudDataTransferServiceS3ToGCSOperator',
+ },
+ 's3_to_sftp_operator': {
+ 'S3ToSFTPOperator': 'airflow.providers.amazon.aws.transfers.s3_to_sftp.S3ToSFTPOperator',
+ },
+ 'segment_track_event_operator': {
+ 'SegmentTrackEventOperator': (
+ 'airflow.providers.segment.operators.segment_track_event.SegmentTrackEventOperator'
+ ),
+ },
+ 'sftp_operator': {
+ 'SFTPOperator': 'airflow.providers.sftp.operators.sftp.SFTPOperator',
+ },
+ 'sftp_to_s3_operator': {
+ 'SFTPToS3Operator': 'airflow.providers.amazon.aws.transfers.sftp_to_s3.SFTPToS3Operator',
+ },
+ 'slack_webhook_operator': {
+ 'SlackWebhookOperator': 'airflow.providers.slack.operators.slack_webhook.SlackWebhookOperator',
+ },
+ 'snowflake_operator': {
+ 'SnowflakeOperator': 'airflow.providers.snowflake.operators.snowflake.SnowflakeOperator',
+ },
+ 'sns_publish_operator': {
+ 'SnsPublishOperator': 'airflow.providers.amazon.aws.operators.sns.SnsPublishOperator',
+ },
+ 'spark_jdbc_operator': {
+ 'SparkJDBCOperator': 'airflow.providers.apache.spark.operators.spark_jdbc.SparkJDBCOperator',
+ 'SparkSubmitOperator': 'airflow.providers.apache.spark.operators.spark_jdbc.SparkSubmitOperator',
+ },
+ 'spark_sql_operator': {
+ 'SparkSqlOperator': 'airflow.providers.apache.spark.operators.spark_sql.SparkSqlOperator',
+ },
+ 'spark_submit_operator': {
+ 'SparkSubmitOperator': 'airflow.providers.apache.spark.operators.spark_submit.SparkSubmitOperator',
+ },
+ 'sql_to_gcs': {
+ 'BaseSQLToGCSOperator': 'airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator',
+ 'BaseSQLToGoogleCloudStorageOperator': (
+ 'airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator'
+ ),
+ },
+ 'sqoop_operator': {
+ 'SqoopOperator': 'airflow.providers.apache.sqoop.operators.sqoop.SqoopOperator',
+ },
+ 'ssh_operator': {
+ 'SSHOperator': 'airflow.providers.ssh.operators.ssh.SSHOperator',
+ },
+ 'vertica_operator': {
+ 'VerticaOperator': 'airflow.providers.vertica.operators.vertica.VerticaOperator',
+ },
+ 'vertica_to_hive': {
+ 'VerticaToHiveOperator': (
+ 'airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator'
+ ),
+ 'VerticaToHiveTransfer': (
+ 'airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator'
+ ),
+ },
+ 'vertica_to_mysql': {
+ 'VerticaToMySqlOperator': 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator',
+ 'VerticaToMySqlTransfer': 'airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator',
+ },
+ 'wasb_delete_blob_operator': {
+ 'WasbDeleteBlobOperator': (
+ 'airflow.providers.microsoft.azure.operators.wasb_delete_blob.WasbDeleteBlobOperator'
+ ),
+ },
+ 'winrm_operator': {
+ 'WinRMOperator': 'airflow.providers.microsoft.winrm.operators.winrm.WinRMOperator',
+ },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
diff --git a/airflow/contrib/operators/adls_list_operator.py b/airflow/contrib/operators/adls_list_operator.py
deleted file mode 100644
index d4d1394f89fe0..0000000000000
--- a/airflow/contrib/operators/adls_list_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.adls`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.operators.adls import ADLSListOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.adls`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class AzureDataLakeStorageListOperator(ADLSListOperator):
- """
- This class is deprecated.
- Please use Please use :mod:`airflow.providers.microsoft.azure.operators.adls.ADLSListOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use Please use :mod:`airflow.providers.microsoft.azure.operators.adls.ADLSListOperator`""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/adls_to_gcs.py b/airflow/contrib/operators/adls_to_gcs.py
deleted file mode 100644
index 0497d4c259023..0000000000000
--- a/airflow/contrib/operators/adls_to_gcs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.adls_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.adls_to_gcs import ADLSToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.adls_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class AdlsToGoogleCloudStorageOperator(ADLSToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/aws_athena_operator.py b/airflow/contrib/operators/aws_athena_operator.py
deleted file mode 100644
index e799c74635ed2..0000000000000
--- a/airflow/contrib/operators/aws_athena_operator.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.athena`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.athena import AWSAthenaOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.athena`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/aws_sqs_publish_operator.py b/airflow/contrib/operators/aws_sqs_publish_operator.py
deleted file mode 100644
index 0ecc8a64f359d..0000000000000
--- a/airflow/contrib/operators/aws_sqs_publish_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sqs`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sqs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class SQSPublishOperator(SqsPublishOperator):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/awsbatch_operator.py b/airflow/contrib/operators/awsbatch_operator.py
deleted file mode 100644
index a6be224cb93eb..0000000000000
--- a/airflow/contrib/operators/awsbatch_operator.py
+++ /dev/null
@@ -1,75 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-"""
-This module is deprecated. Please use:
-
-- :mod:`airflow.providers.amazon.aws.operators.batch`
-- :mod:`airflow.providers.amazon.aws.hooks.batch_client`
-- :mod:`airflow.providers.amazon.aws.hooks.batch_waiters``
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.batch_client import AwsBatchProtocol
-from airflow.providers.amazon.aws.operators.batch import AwsBatchOperator
-from airflow.typing_compat import Protocol, runtime_checkable
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.batch`, "
- "`airflow.providers.amazon.aws.hooks.batch_client`, and "
- "`airflow.providers.amazon.aws.hooks.batch_waiters`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class AWSBatchOperator(AwsBatchOperator):
- """
- This class is deprecated. Please use
- `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.amazon.aws.operators.batch.AwsBatchOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-@runtime_checkable
-class BatchProtocol(AwsBatchProtocol, Protocol):
- """
- This class is deprecated. Please use
- `airflow.providers.amazon.aws.hooks.batch_client.AwsBatchProtocol`.
- """
-
- # A Protocol cannot be instantiated
-
- def __new__(cls, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.amazon.aws.hooks.batch_client.AwsBatchProtocol`.""",
- DeprecationWarning,
- stacklevel=2,
- )
diff --git a/airflow/contrib/operators/azure_container_instances_operator.py b/airflow/contrib/operators/azure_container_instances_operator.py
deleted file mode 100644
index d084748ca06df..0000000000000
--- a/airflow/contrib/operators/azure_container_instances_operator.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated. Please use
-`airflow.providers.microsoft.azure.operators.container_instances`.
-"""
-import warnings
-
-from airflow.providers.microsoft.azure.operators.container_instances import ( # noqa
- AzureContainerInstancesOperator,
-)
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.microsoft.azure.operators.container_instances`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/azure_cosmos_operator.py b/airflow/contrib/operators/azure_cosmos_operator.py
deleted file mode 100644
index 269c8357c02d3..0000000000000
--- a/airflow/contrib/operators/azure_cosmos_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.operators.cosmos`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.operators.cosmos import AzureCosmosInsertDocumentOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.cosmos`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/bigquery_check_operator.py b/airflow/contrib/operators/bigquery_check_operator.py
deleted file mode 100644
index 39b658cacab8d..0000000000000
--- a/airflow/contrib/operators/bigquery_check_operator.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.bigquery import ( # noqa
- BigQueryCheckOperator,
- BigQueryIntervalCheckOperator,
- BigQueryValueCheckOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/bigquery_get_data.py b/airflow/contrib/operators/bigquery_get_data.py
deleted file mode 100644
index 00c8575df3d70..0000000000000
--- a/airflow/contrib/operators/bigquery_get_data.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.bigquery import BigQueryGetDataOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/bigquery_operator.py b/airflow/contrib/operators/bigquery_operator.py
deleted file mode 100644
index 6fe8f0816263f..0000000000000
--- a/airflow/contrib/operators/bigquery_operator.py
+++ /dev/null
@@ -1,55 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.bigquery import ( # noqa; noqa; noqa; noqa; noqa
- BigQueryCreateEmptyDatasetOperator,
- BigQueryCreateEmptyTableOperator,
- BigQueryCreateExternalTableOperator,
- BigQueryDeleteDatasetOperator,
- BigQueryExecuteQueryOperator,
- BigQueryGetDatasetOperator,
- BigQueryGetDatasetTablesOperator,
- BigQueryPatchDatasetOperator,
- BigQueryUpdateDatasetOperator,
- BigQueryUpsertTableOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class BigQueryOperator(BigQueryExecuteQueryOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryExecuteQueryOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/bigquery_table_delete_operator.py b/airflow/contrib/operators/bigquery_table_delete_operator.py
deleted file mode 100644
index 13822a1844409..0000000000000
--- a/airflow/contrib/operators/bigquery_table_delete_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.bigquery import BigQueryDeleteTableOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class BigQueryTableDeleteOperator(BigQueryDeleteTableOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigquery.BigQueryDeleteTableOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/bigquery_to_bigquery.py b/airflow/contrib/operators/bigquery_to_bigquery.py
deleted file mode 100644
index 84c26fb718c91..0000000000000
--- a/airflow/contrib/operators/bigquery_to_bigquery.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.transfers.bigquery_to_bigquery`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.bigquery_to_bigquery import BigQueryToBigQueryOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/bigquery_to_gcs.py b/airflow/contrib/operators/bigquery_to_gcs.py
deleted file mode 100644
index 702171187e885..0000000000000
--- a/airflow/contrib/operators/bigquery_to_gcs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.bigquery_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.bigquery_to_gcs import BigQueryToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class BigQueryToCloudStorageOperator(BigQueryToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.bigquery_to_gcs.BigQueryToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/bigquery_to_mysql_operator.py b/airflow/contrib/operators/bigquery_to_mysql_operator.py
deleted file mode 100644
index 401921cff502b..0000000000000
--- a/airflow/contrib/operators/bigquery_to_mysql_operator.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.transfers.bigquery_to_mysql`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.bigquery_to_mysql import BigQueryToMySqlOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.bigquery_to_mysql`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/cassandra_to_gcs.py b/airflow/contrib/operators/cassandra_to_gcs.py
deleted file mode 100644
index bb4b244f0d4cf..0000000000000
--- a/airflow/contrib/operators/cassandra_to_gcs.py
+++ /dev/null
@@ -1,47 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.transfers.cassandra_to_gcs`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.cassandra_to_gcs import CassandraToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CassandraToGoogleCloudStorageOperator(CassandraToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/databricks_operator.py b/airflow/contrib/operators/databricks_operator.py
deleted file mode 100644
index b591dd63b8274..0000000000000
--- a/airflow/contrib/operators/databricks_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.databricks.operators.databricks`."""
-
-import warnings
-
-from airflow.providers.databricks.operators.databricks import ( # noqa
- DatabricksRunNowOperator,
- DatabricksSubmitRunOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.databricks.operators.databricks`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
deleted file mode 100644
index d2e445add02f1..0000000000000
--- a/airflow/contrib/operators/dataflow_operator.py
+++ /dev/null
@@ -1,82 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.dataflow`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.dataflow import (
- DataflowCreateJavaJobOperator,
- DataflowCreatePythonJobOperator,
- DataflowTemplatedJobStartOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dataflow`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DataFlowJavaOperator(DataflowCreateJavaJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreateJavaJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataFlowPythonOperator(DataflowCreatePythonJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataflow.DataflowCreatePythonJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataflowTemplateOperator(DataflowTemplatedJobStartOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataflow.DataflowTemplatedJobStartOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/dataproc_operator.py b/airflow/contrib/operators/dataproc_operator.py
deleted file mode 100644
index b655ce630142a..0000000000000
--- a/airflow/contrib/operators/dataproc_operator.py
+++ /dev/null
@@ -1,244 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.dataproc`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.dataproc import (
- DataprocCreateClusterOperator,
- DataprocDeleteClusterOperator,
- DataprocInstantiateInlineWorkflowTemplateOperator,
- DataprocInstantiateWorkflowTemplateOperator,
- DataprocJobBaseOperator,
- DataprocScaleClusterOperator,
- DataprocSubmitHadoopJobOperator,
- DataprocSubmitHiveJobOperator,
- DataprocSubmitPigJobOperator,
- DataprocSubmitPySparkJobOperator,
- DataprocSubmitSparkJobOperator,
- DataprocSubmitSparkSqlJobOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dataproc`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DataprocClusterCreateOperator(DataprocCreateClusterOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocCreateClusterOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataprocClusterDeleteOperator(DataprocDeleteClusterOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocDeleteClusterOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataprocClusterScaleOperator(DataprocScaleClusterOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcHadoopOperator(DataprocSubmitHadoopJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHadoopJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcHiveOperator(DataprocSubmitHiveJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitHiveJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcJobBaseOperator(DataprocJobBaseOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocJobBaseOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcPigOperator(DataprocSubmitPigJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPigJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcPySparkOperator(DataprocSubmitPySparkJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitPySparkJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcSparkOperator(DataprocSubmitSparkJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataProcSparkSqlOperator(DataprocSubmitSparkSqlJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocSubmitSparkSqlJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataprocWorkflowTemplateInstantiateInlineOperator(DataprocInstantiateInlineWorkflowTemplateOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc
- .DataprocInstantiateInlineWorkflowTemplateOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class DataprocWorkflowTemplateInstantiateOperator(DataprocInstantiateWorkflowTemplateOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateWorkflowTemplateOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.dataproc
- .DataprocInstantiateWorkflowTemplateOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/datastore_export_operator.py b/airflow/contrib/operators/datastore_export_operator.py
deleted file mode 100644
index 085ee132607da..0000000000000
--- a/airflow/contrib/operators/datastore_export_operator.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.datastore`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.datastore import CloudDatastoreExportEntitiesOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DatastoreExportOperator(CloudDatastoreExportEntitiesOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.datastore.CloudDatastoreExportEntitiesOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/datastore_import_operator.py b/airflow/contrib/operators/datastore_import_operator.py
deleted file mode 100644
index 5b15cd23c9516..0000000000000
--- a/airflow/contrib/operators/datastore_import_operator.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.datastore`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.datastore import CloudDatastoreImportEntitiesOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.datastore`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DatastoreImportOperator(CloudDatastoreImportEntitiesOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.datastore.CloudDatastoreImportEntitiesOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/dingding_operator.py b/airflow/contrib/operators/dingding_operator.py
deleted file mode 100644
index bfe91e8a72491..0000000000000
--- a/airflow/contrib/operators/dingding_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.dingding.operators.dingding`."""
-
-import warnings
-
-from airflow.providers.dingding.operators.dingding import DingdingOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.dingding.operators.dingding`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/discord_webhook_operator.py b/airflow/contrib/operators/discord_webhook_operator.py
deleted file mode 100644
index be5809afbfcbd..0000000000000
--- a/airflow/contrib/operators/discord_webhook_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.discord.operators.discord_webhook`."""
-
-import warnings
-
-from airflow.providers.discord.operators.discord_webhook import DiscordWebhookOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.discord.operators.discord_webhook`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/docker_swarm_operator.py b/airflow/contrib/operators/docker_swarm_operator.py
deleted file mode 100644
index b023da796a4c8..0000000000000
--- a/airflow/contrib/operators/docker_swarm_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.docker.operators.docker_swarm`."""
-
-import warnings
-
-from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.docker.operators.docker_swarm`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/druid_operator.py b/airflow/contrib/operators/druid_operator.py
deleted file mode 100644
index 20dff77192313..0000000000000
--- a/airflow/contrib/operators/druid_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.operators.druid`."""
-
-import warnings
-
-from airflow.providers.apache.druid.operators.druid import DruidOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.druid.operators.druid`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/dynamodb_to_s3.py b/airflow/contrib/operators/dynamodb_to_s3.py
deleted file mode 100644
index a2054007c0678..0000000000000
--- a/airflow/contrib/operators/dynamodb_to_s3.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.dynamodb_to_s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.dynamodb_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
deleted file mode 100644
index 569df0c284774..0000000000000
--- a/airflow/contrib/operators/ecs_operator.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ecs`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.ecs import ECSOperator, ECSProtocol
-
-__all__ = ["ECSOperator", "ECSProtocol"]
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ecs`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/emr_add_steps_operator.py b/airflow/contrib/operators/emr_add_steps_operator.py
deleted file mode 100644
index e53f284e447c6..0000000000000
--- a/airflow/contrib/operators/emr_add_steps_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr_add_steps`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr_add_steps import EmrAddStepsOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_add_steps`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/emr_create_job_flow_operator.py b/airflow/contrib/operators/emr_create_job_flow_operator.py
deleted file mode 100644
index 16f1ce7f61f21..0000000000000
--- a/airflow/contrib/operators/emr_create_job_flow_operator.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr_create_job_flow`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr_create_job_flow import EmrCreateJobFlowOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_create_job_flow`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/emr_terminate_job_flow_operator.py b/airflow/contrib/operators/emr_terminate_job_flow_operator.py
deleted file mode 100644
index 7c73bc32dc8db..0000000000000
--- a/airflow/contrib/operators/emr_terminate_job_flow_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr_terminate_job_flow import EmrTerminateJobFlowOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr_terminate_job_flow`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/file_to_gcs.py b/airflow/contrib/operators/file_to_gcs.py
deleted file mode 100644
index 227cd74979c2f..0000000000000
--- a/airflow/contrib/operators/file_to_gcs.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.local_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.local_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class FileToGoogleCloudStorageOperator(LocalFilesystemToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.transfers.local_to_gcs.LocalFilesystemToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/file_to_wasb.py b/airflow/contrib/operators/file_to_wasb.py
deleted file mode 100644
index eb6275d5bbd1f..0000000000000
--- a/airflow/contrib/operators/file_to_wasb.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.transfers.local_to_wasb`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.transfers.local_to_wasb import LocalFilesystemToWasbOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.transfers.local_to_wasb`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class FileToWasbOperator(LocalFilesystemToWasbOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.microsoft.azure.transfers.local_to_wasb.LocalFilesystemToWasbOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_bigtable_operator.py b/airflow/contrib/operators/gcp_bigtable_operator.py
deleted file mode 100644
index f45fde7d5a388..0000000000000
--- a/airflow/contrib/operators/gcp_bigtable_operator.py
+++ /dev/null
@@ -1,136 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable`
-or `airflow.providers.google.cloud.sensors.bigtable`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.bigtable import (
- BigtableCreateInstanceOperator,
- BigtableCreateTableOperator,
- BigtableDeleteInstanceOperator,
- BigtableDeleteTableOperator,
- BigtableUpdateClusterOperator,
-)
-from airflow.providers.google.cloud.sensors.bigtable import BigtableTableReplicationCompletedSensor
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.bigtable`"
- " or `airflow.providers.google.cloud.sensors.bigtable`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class BigtableClusterUpdateOperator(BigtableUpdateClusterOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableUpdateClusterOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class BigtableInstanceCreateOperator(BigtableCreateInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateInstanceOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class BigtableInstanceDeleteOperator(BigtableDeleteInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteInstanceOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class BigtableTableCreateOperator(BigtableCreateTableOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableCreateTableOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class BigtableTableDeleteOperator(BigtableDeleteTableOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.bigtable.BigtableDeleteTableOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class BigtableTableWaitForReplicationSensor(BigtableTableReplicationCompletedSensor):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.sensors.bigtable.BigtableTableReplicationCompletedSensor`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_cloud_build_operator.py b/airflow/contrib/operators/gcp_cloud_build_operator.py
deleted file mode 100644
index 443fdbfdab73f..0000000000000
--- a/airflow/contrib/operators/gcp_cloud_build_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.cloud_build`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.cloud_build import CloudBuildCreateBuildOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_build`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/gcp_compute_operator.py b/airflow/contrib/operators/gcp_compute_operator.py
deleted file mode 100644
index d943fdf5587dd..0000000000000
--- a/airflow/contrib/operators/gcp_compute_operator.py
+++ /dev/null
@@ -1,138 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.compute`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.compute import (
- ComputeEngineBaseOperator,
- ComputeEngineCopyInstanceTemplateOperator,
- ComputeEngineInstanceGroupUpdateManagerTemplateOperator,
- ComputeEngineSetMachineTypeOperator,
- ComputeEngineStartInstanceOperator,
- ComputeEngineStopInstanceOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.compute`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GceBaseOperator(ComputeEngineBaseOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineBaseOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GceInstanceGroupManagerUpdateTemplateOperator(ComputeEngineInstanceGroupUpdateManagerTemplateOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute
- .ComputeEngineInstanceGroupUpdateManagerTemplateOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated. Please use
- `airflow.providers.google.cloud.operators.compute
- .ComputeEngineInstanceGroupUpdateManagerTemplateOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GceInstanceStartOperator(ComputeEngineStartInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators
- .compute.ComputeEngineStartInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute
- .ComputeEngineStartInstanceOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GceInstanceStopOperator(ComputeEngineStopInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineStopInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute
- .ComputeEngineStopInstanceOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GceInstanceTemplateCopyOperator(ComputeEngineCopyInstanceTemplateOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineCopyInstanceTemplateOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """"This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute
- .ComputeEngineCopyInstanceTemplateOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GceSetMachineTypeOperator(ComputeEngineSetMachineTypeOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute.ComputeEngineSetMachineTypeOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.compute
- .ComputeEngineSetMachineTypeOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_container_operator.py b/airflow/contrib/operators/gcp_container_operator.py
deleted file mode 100644
index 8a26c40a8fe5d..0000000000000
--- a/airflow/contrib/operators/gcp_container_operator.py
+++ /dev/null
@@ -1,83 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.operators.kubernetes_engine`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.kubernetes_engine import (
- GKECreateClusterOperator,
- GKEDeleteClusterOperator,
- GKEStartPodOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.kubernetes_engine`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GKEClusterCreateOperator(GKECreateClusterOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.container.GKECreateClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.container.GKECreateClusterOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GKEClusterDeleteOperator(GKEDeleteClusterOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.container.GKEDeleteClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.container.GKEDeleteClusterOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GKEPodOperator(GKEStartPodOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.container.GKEStartPodOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.container.GKEStartPodOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_dlp_operator.py b/airflow/contrib/operators/gcp_dlp_operator.py
deleted file mode 100644
index f5b4c07d4cc8d..0000000000000
--- a/airflow/contrib/operators/gcp_dlp_operator.py
+++ /dev/null
@@ -1,123 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.dlp`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.dlp import ( # noqa
- CloudDLPCancelDLPJobOperator,
- CloudDLPCreateDeidentifyTemplateOperator,
- CloudDLPCreateDLPJobOperator,
- CloudDLPCreateInspectTemplateOperator,
- CloudDLPCreateJobTriggerOperator,
- CloudDLPCreateStoredInfoTypeOperator,
- CloudDLPDeidentifyContentOperator,
- CloudDLPDeleteDeidentifyTemplateOperator,
- CloudDLPDeleteDLPJobOperator,
- CloudDLPDeleteInspectTemplateOperator,
- CloudDLPDeleteJobTriggerOperator,
- CloudDLPDeleteStoredInfoTypeOperator,
- CloudDLPGetDeidentifyTemplateOperator,
- CloudDLPGetDLPJobOperator,
- CloudDLPGetDLPJobTriggerOperator,
- CloudDLPGetInspectTemplateOperator,
- CloudDLPGetStoredInfoTypeOperator,
- CloudDLPInspectContentOperator,
- CloudDLPListDeidentifyTemplatesOperator,
- CloudDLPListDLPJobsOperator,
- CloudDLPListInfoTypesOperator,
- CloudDLPListInspectTemplatesOperator,
- CloudDLPListJobTriggersOperator,
- CloudDLPListStoredInfoTypesOperator,
- CloudDLPRedactImageOperator,
- CloudDLPReidentifyContentOperator,
- CloudDLPUpdateDeidentifyTemplateOperator,
- CloudDLPUpdateInspectTemplateOperator,
- CloudDLPUpdateJobTriggerOperator,
- CloudDLPUpdateStoredInfoTypeOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.dlp`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudDLPDeleteDlpJobOperator(CloudDLPDeleteDLPJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPDeleteDLPJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudDLPGetDlpJobOperator(CloudDLPGetDLPJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudDLPGetJobTripperOperator(CloudDLPGetDLPJobTriggerOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPGetDLPJobTriggerOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudDLPListDlpJobsOperator(CloudDLPListDLPJobsOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.dlp.CloudDLPListDLPJobsOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_function_operator.py b/airflow/contrib/operators/gcp_function_operator.py
deleted file mode 100644
index 4ca96c5c8a648..0000000000000
--- a/airflow/contrib/operators/gcp_function_operator.py
+++ /dev/null
@@ -1,65 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.functions`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.functions import (
- CloudFunctionDeleteFunctionOperator,
- CloudFunctionDeployFunctionOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.functions`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GcfFunctionDeleteOperator(CloudFunctionDeleteFunctionOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.function.CloudFunctionDeleteFunctionOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.function.CloudFunctionDeleteFunctionOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcfFunctionDeployOperator(CloudFunctionDeployFunctionOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.function.CloudFunctionDeployFunctionOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.function.CloudFunctionDeployFunctionOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_natural_language_operator.py b/airflow/contrib/operators/gcp_natural_language_operator.py
deleted file mode 100644
index 8a98ccf6903f0..0000000000000
--- a/airflow/contrib/operators/gcp_natural_language_operator.py
+++ /dev/null
@@ -1,114 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.natural_language`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.natural_language import (
- CloudNaturalLanguageAnalyzeEntitiesOperator,
- CloudNaturalLanguageAnalyzeEntitySentimentOperator,
- CloudNaturalLanguageAnalyzeSentimentOperator,
- CloudNaturalLanguageClassifyTextOperator,
-)
-
-warnings.warn(
- """This module is deprecated.
- Please use `airflow.providers.google.cloud.operators.natural_language`
- """,
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudLanguageAnalyzeEntitiesOperator(CloudNaturalLanguageAnalyzeEntitiesOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.natural_language.CloudNaturalLanguageAnalyzeEntitiesOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.natural_language
- .CloudNaturalLanguageAnalyzeEntitiesOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudLanguageAnalyzeEntitySentimentOperator(CloudNaturalLanguageAnalyzeEntitySentimentOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.natural_language
- .CloudNaturalLanguageAnalyzeEntitySentimentOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.natural_language
- .CloudNaturalLanguageAnalyzeEntitySentimentOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudLanguageAnalyzeSentimentOperator(CloudNaturalLanguageAnalyzeSentimentOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.natural_language.CloudNaturalLanguageAnalyzeSentimentOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.natural_language
- .CloudNaturalLanguageAnalyzeSentimentOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudLanguageClassifyTextOperator(CloudNaturalLanguageClassifyTextOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.natural_language
- .CloudNaturalLanguageClassifyTextOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.natural_language
- .CloudNaturalLanguageClassifyTextOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_spanner_operator.py b/airflow/contrib/operators/gcp_spanner_operator.py
deleted file mode 100644
index b2e50c3a80f6a..0000000000000
--- a/airflow/contrib/operators/gcp_spanner_operator.py
+++ /dev/null
@@ -1,125 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.spanner`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.spanner import (
- SpannerDeleteDatabaseInstanceOperator,
- SpannerDeleteInstanceOperator,
- SpannerDeployDatabaseInstanceOperator,
- SpannerDeployInstanceOperator,
- SpannerQueryDatabaseInstanceOperator,
- SpannerUpdateDatabaseInstanceOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.spanner`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudSpannerInstanceDatabaseDeleteOperator(SpannerDeleteDatabaseInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeleteDatabaseInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- self.__doc__,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudSpannerInstanceDatabaseDeployOperator(SpannerDeployDatabaseInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeployDatabaseInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- self.__doc__,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudSpannerInstanceDatabaseQueryOperator(SpannerQueryDatabaseInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerQueryDatabaseInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- self.__doc__,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudSpannerInstanceDatabaseUpdateOperator(SpannerUpdateDatabaseInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerUpdateDatabaseInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- self.__doc__,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudSpannerInstanceDeleteOperator(SpannerDeleteInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeleteInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- self.__doc__,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudSpannerInstanceDeployOperator(SpannerDeployInstanceOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.spanner.SpannerDeployInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- self.__doc__,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_speech_to_text_operator.py b/airflow/contrib/operators/gcp_speech_to_text_operator.py
deleted file mode 100644
index 499c8ab7f6e01..0000000000000
--- a/airflow/contrib/operators/gcp_speech_to_text_operator.py
+++ /dev/null
@@ -1,47 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.speech_to_text`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.speech_to_text import CloudSpeechToTextRecognizeSpeechOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.speech_to_text`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GcpSpeechToTextRecognizeSpeechOperator(CloudSpeechToTextRecognizeSpeechOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.speech_to_text.CloudSpeechToTextRecognizeSpeechOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.speech_to_text
- .CloudSpeechToTextRecognizeSpeechOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_sql_operator.py b/airflow/contrib/operators/gcp_sql_operator.py
deleted file mode 100644
index 5cbbf8a80a294..0000000000000
--- a/airflow/contrib/operators/gcp_sql_operator.py
+++ /dev/null
@@ -1,149 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.cloud_sql`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.cloud_sql import (
- CloudSQLBaseOperator,
- CloudSQLCreateInstanceDatabaseOperator,
- CloudSQLCreateInstanceOperator,
- CloudSQLDeleteInstanceDatabaseOperator,
- CloudSQLDeleteInstanceOperator,
- CloudSQLExecuteQueryOperator,
- CloudSQLExportInstanceOperator,
- CloudSQLImportInstanceOperator,
- CloudSQLInstancePatchOperator,
- CloudSQLPatchInstanceDatabaseOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.cloud_sql`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudSqlBaseOperator(CloudSQLBaseOperator):
- """
- This class is deprecated. Please use
- `airflow.providers.google.cloud.operators.sql.CloudSQLBaseOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceCreateOperator(CloudSQLCreateInstanceOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLCreateInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceDatabaseCreateOperator(CloudSQLCreateInstanceDatabaseOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLCreateInstanceDatabaseOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceDatabaseDeleteOperator(CloudSQLDeleteInstanceDatabaseOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLDeleteInstanceDatabaseOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceDatabasePatchOperator(CloudSQLPatchInstanceDatabaseOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLPatchInstanceDatabaseOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceDeleteOperator(CloudSQLDeleteInstanceOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLDeleteInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceExportOperator(CloudSQLExportInstanceOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLExportInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstanceImportOperator(CloudSQLImportInstanceOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators.sql
- .CloudSQLImportInstanceOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlInstancePatchOperator(CloudSQLInstancePatchOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators
- .sql.CloudSQLInstancePatchOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
-
-
-class CloudSqlQueryOperator(CloudSQLExecuteQueryOperator):
- """
- This class is deprecated. Please use `airflow.providers.google.cloud.operators
- .sql.CloudSQLExecuteQueryOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(self.__doc__, DeprecationWarning, stacklevel=2)
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_tasks_operator.py b/airflow/contrib/operators/gcp_tasks_operator.py
deleted file mode 100644
index 319ddb4fe2ad5..0000000000000
--- a/airflow/contrib/operators/gcp_tasks_operator.py
+++ /dev/null
@@ -1,42 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.tasks`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.tasks import ( # noqa
- CloudTasksQueueCreateOperator,
- CloudTasksQueueDeleteOperator,
- CloudTasksQueueGetOperator,
- CloudTasksQueuePauseOperator,
- CloudTasksQueuePurgeOperator,
- CloudTasksQueueResumeOperator,
- CloudTasksQueuesListOperator,
- CloudTasksQueueUpdateOperator,
- CloudTasksTaskCreateOperator,
- CloudTasksTaskDeleteOperator,
- CloudTasksTaskGetOperator,
- CloudTasksTaskRunOperator,
- CloudTasksTasksListOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.tasks`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/gcp_text_to_speech_operator.py b/airflow/contrib/operators/gcp_text_to_speech_operator.py
deleted file mode 100644
index af5ff2f39d122..0000000000000
--- a/airflow/contrib/operators/gcp_text_to_speech_operator.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.text_to_speech`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.text_to_speech import CloudTextToSpeechSynthesizeOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.text_to_speech`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GcpTextToSpeechSynthesizeOperator(CloudTextToSpeechSynthesizeOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.text_to_speech.CloudTextToSpeechSynthesizeOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_transfer_operator.py b/airflow/contrib/operators/gcp_transfer_operator.py
deleted file mode 100644
index bd7c7bc2cb752..0000000000000
--- a/airflow/contrib/operators/gcp_transfer_operator.py
+++ /dev/null
@@ -1,229 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import (
- CloudDataTransferServiceCancelOperationOperator,
- CloudDataTransferServiceCreateJobOperator,
- CloudDataTransferServiceDeleteJobOperator,
- CloudDataTransferServiceGCSToGCSOperator,
- CloudDataTransferServiceGetOperationOperator,
- CloudDataTransferServiceListOperationsOperator,
- CloudDataTransferServicePauseOperationOperator,
- CloudDataTransferServiceResumeOperationOperator,
- CloudDataTransferServiceS3ToGCSOperator,
- CloudDataTransferServiceUpdateJobOperator,
-)
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GcpTransferServiceJobCreateOperator(CloudDataTransferServiceCreateJobOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceCreateJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceCreateJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceJobDeleteOperator(CloudDataTransferServiceDeleteJobOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceDeleteJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceDeleteJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceJobUpdateOperator(CloudDataTransferServiceUpdateJobOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceUpdateJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceUpdateJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceOperationCancelOperator(CloudDataTransferServiceCancelOperationOperator):
- """
- This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.cloud_storage_transfer_service.CloudDataTransferServiceCancelOperationOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceCancelOperationOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceOperationGetOperator(CloudDataTransferServiceGetOperationOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceGetOperationOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceGetOperationOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceOperationPauseOperator(CloudDataTransferServicePauseOperationOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServicePauseOperationOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServicePauseOperationOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceOperationResumeOperator(CloudDataTransferServiceResumeOperationOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceResumeOperationOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceResumeOperationOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GcpTransferServiceOperationsListOperator(CloudDataTransferServiceListOperationsOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceListOperationsOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceListOperationsOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GoogleCloudStorageToGoogleCloudStorageTransferOperator(CloudDataTransferServiceGCSToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceGCSToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceGCSToGCSOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class S3ToGoogleCloudStorageTransferOperator(CloudDataTransferServiceS3ToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceS3ToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service
- .CloudDataTransferServiceS3ToGCSOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_translate_operator.py b/airflow/contrib/operators/gcp_translate_operator.py
deleted file mode 100644
index c61cc8492c250..0000000000000
--- a/airflow/contrib/operators/gcp_translate_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.translate`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.translate import CloudTranslateTextOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/gcp_translate_speech_operator.py b/airflow/contrib/operators/gcp_translate_speech_operator.py
deleted file mode 100644
index 2e0bb70787b61..0000000000000
--- a/airflow/contrib/operators/gcp_translate_speech_operator.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.operators.translate_speech`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.translate_speech import CloudTranslateSpeechOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.translate_speech`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GcpTranslateSpeechOperator(CloudTranslateSpeechOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.translate_speech.CloudTranslateSpeechOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcp_video_intelligence_operator.py b/airflow/contrib/operators/gcp_video_intelligence_operator.py
deleted file mode 100644
index a82fc9aeb8411..0000000000000
--- a/airflow/contrib/operators/gcp_video_intelligence_operator.py
+++ /dev/null
@@ -1,35 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.operators.video_intelligence`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.video_intelligence import ( # noqa
- CloudVideoIntelligenceDetectVideoExplicitContentOperator,
- CloudVideoIntelligenceDetectVideoLabelsOperator,
- CloudVideoIntelligenceDetectVideoShotsOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.video_intelligence`",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/gcp_vision_operator.py b/airflow/contrib/operators/gcp_vision_operator.py
deleted file mode 100644
index 09a5b1e817e28..0000000000000
--- a/airflow/contrib/operators/gcp_vision_operator.py
+++ /dev/null
@@ -1,227 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.vision`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.vision import ( # noqa
- CloudVisionAddProductToProductSetOperator,
- CloudVisionCreateProductOperator,
- CloudVisionCreateProductSetOperator,
- CloudVisionCreateReferenceImageOperator,
- CloudVisionDeleteProductOperator,
- CloudVisionDeleteProductSetOperator,
- CloudVisionDetectImageLabelsOperator,
- CloudVisionDetectImageSafeSearchOperator,
- CloudVisionDetectTextOperator,
- CloudVisionGetProductOperator,
- CloudVisionGetProductSetOperator,
- CloudVisionImageAnnotateOperator,
- CloudVisionRemoveProductFromProductSetOperator,
- CloudVisionTextDetectOperator,
- CloudVisionUpdateProductOperator,
- CloudVisionUpdateProductSetOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.vision`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class CloudVisionAnnotateImageOperator(CloudVisionImageAnnotateOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionImageAnnotateOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionDetectDocumentTextOperator(CloudVisionTextDetectOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionTextDetectOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductCreateOperator(CloudVisionCreateProductOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductDeleteOperator(CloudVisionDeleteProductOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductGetOperator(CloudVisionGetProductOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductSetCreateOperator(CloudVisionCreateProductSetOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.vision.CloudVisionCreateProductSetOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductSetDeleteOperator(CloudVisionDeleteProductSetOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.vision.CloudVisionDeleteProductSetOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductSetGetOperator(CloudVisionGetProductSetOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.vision.CloudVisionGetProductSetOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductSetUpdateOperator(CloudVisionUpdateProductSetOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductSetOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionProductUpdateOperator(CloudVisionUpdateProductOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.vision.CloudVisionUpdateProductOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class CloudVisionReferenceImageCreateOperator(CloudVisionCreateReferenceImageOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.vision.CloudVisionCreateReferenceImageOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_acl_operator.py b/airflow/contrib/operators/gcs_acl_operator.py
deleted file mode 100644
index f3e6afa303602..0000000000000
--- a/airflow/contrib/operators/gcs_acl_operator.py
+++ /dev/null
@@ -1,63 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.gcs import (
- GCSBucketCreateAclEntryOperator,
- GCSObjectCreateAclEntryOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageBucketCreateAclEntryOperator(GCSBucketCreateAclEntryOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSBucketCreateAclEntryOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class GoogleCloudStorageObjectCreateAclEntryOperator(GCSObjectCreateAclEntryOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSObjectCreateAclEntryOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_delete_operator.py b/airflow/contrib/operators/gcs_delete_operator.py
deleted file mode 100644
index 42fa2a89057d1..0000000000000
--- a/airflow/contrib/operators/gcs_delete_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.gcs import GCSDeleteObjectsOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageDeleteOperator(GCSDeleteObjectsOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSDeleteObjectsOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_download_operator.py b/airflow/contrib/operators/gcs_download_operator.py
deleted file mode 100644
index a6904320bd2c2..0000000000000
--- a/airflow/contrib/operators/gcs_download_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.gcs_to_local`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_local`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageDownloadOperator(GCSToLocalFilesystemOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.gcs_to_local.GCSToLocalFilesystemOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.gcs_to_local.GCSToLocalFilesystemOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_list_operator.py b/airflow/contrib/operators/gcs_list_operator.py
deleted file mode 100644
index a18ec272a5907..0000000000000
--- a/airflow/contrib/operators/gcs_list_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageListOperator(GCSListObjectsOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSListObjectsOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_operator.py b/airflow/contrib/operators/gcs_operator.py
deleted file mode 100644
index 4f0aecbf9c167..0000000000000
--- a/airflow/contrib/operators/gcs_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.gcs import GCSCreateBucketOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageCreateBucketOperator(GCSCreateBucketOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.gcs.GCSCreateBucketOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_to_bq.py b/airflow/contrib/operators/gcs_to_bq.py
deleted file mode 100644
index ab71bb8da7989..0000000000000
--- a/airflow/contrib/operators/gcs_to_bq.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.gcs_to_bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.gcs_to_bigquery import GCSToBigQueryOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageToBigQueryOperator(GCSToBigQueryOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.gcs_to_bigquery.GCSToBigQueryOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_to_gcs.py b/airflow/contrib/operators/gcs_to_gcs.py
deleted file mode 100644
index ab6f2d91c1ba9..0000000000000
--- a/airflow/contrib/operators/gcs_to_gcs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.gcs_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.gcs_to_gcs import GCSToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageToGoogleCloudStorageOperator(GCSToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.gcs_to_gcs.GCSToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py b/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py
deleted file mode 100644
index 75a672f2c18b2..0000000000000
--- a/airflow/contrib/operators/gcs_to_gcs_transfer_operator.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.
-"""
-
-import warnings
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/gcs_to_gdrive_operator.py b/airflow/contrib/operators/gcs_to_gdrive_operator.py
deleted file mode 100644
index 1fb55d1fc3dfe..0000000000000
--- a/airflow/contrib/operators/gcs_to_gdrive_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.suite.transfers.gcs_to_gdrive`."""
-
-import warnings
-
-from airflow.providers.google.suite.transfers.gcs_to_gdrive import GCSToGoogleDriveOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.suite.transfers.gcs_to_gdrive.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/gcs_to_s3.py b/airflow/contrib/operators/gcs_to_s3.py
deleted file mode 100644
index 13aa005ab20a3..0000000000000
--- a/airflow/contrib/operators/gcs_to_s3.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.gcs_to_s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleCloudStorageToS3Operator(GCSToS3Operator):
- """
- This class is deprecated. Please use
- `airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/grpc_operator.py b/airflow/contrib/operators/grpc_operator.py
deleted file mode 100644
index bd8cfbd6003ee..0000000000000
--- a/airflow/contrib/operators/grpc_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.grpc.operators.grpc`."""
-
-import warnings
-
-from airflow.providers.grpc.operators.grpc import GrpcOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.grpc.operators.grpc`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/hive_to_dynamodb.py b/airflow/contrib/operators/hive_to_dynamodb.py
deleted file mode 100644
index ba4f8b967cc98..0000000000000
--- a/airflow/contrib/operators/hive_to_dynamodb.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.hive_to_dynamodb`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.hive_to_dynamodb`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/imap_attachment_to_s3_operator.py b/airflow/contrib/operators/imap_attachment_to_s3_operator.py
deleted file mode 100644
index e82a8bc05aeae..0000000000000
--- a/airflow/contrib/operators/imap_attachment_to_s3_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.transfers.imap_attachment_to_s3`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.imap_attachment_to_s3 import ImapAttachmentToS3Operator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.imap_attachment_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/jenkins_job_trigger_operator.py b/airflow/contrib/operators/jenkins_job_trigger_operator.py
deleted file mode 100644
index 0b401d2b430a7..0000000000000
--- a/airflow/contrib/operators/jenkins_job_trigger_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.jenkins.operators.jenkins_job_trigger`."""
-
-import warnings
-
-from airflow.providers.jenkins.operators.jenkins_job_trigger import JenkinsJobTriggerOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.jenkins.operators.jenkins_job_trigger`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/jira_operator.py b/airflow/contrib/operators/jira_operator.py
deleted file mode 100644
index b6e3b3e124a85..0000000000000
--- a/airflow/contrib/operators/jira_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.jira.operators.jira`."""
-
-import warnings
-
-from airflow.providers.jira.operators.jira import JiraOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.jira.operators.jira`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/kubernetes_pod_operator.py b/airflow/contrib/operators/kubernetes_pod_operator.py
deleted file mode 100644
index 962fa22859013..0000000000000
--- a/airflow/contrib/operators/kubernetes_pod_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.cncf.kubernetes.operators.kubernetes_pod`.
-"""
-
-import warnings
-
-from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.cncf.kubernetes.operators.kubernetes_pod`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/mlengine_operator.py b/airflow/contrib/operators/mlengine_operator.py
deleted file mode 100644
index d5d6b2fb1d24a..0000000000000
--- a/airflow/contrib/operators/mlengine_operator.py
+++ /dev/null
@@ -1,100 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.operators.mlengine`."""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.mlengine import (
- MLEngineManageModelOperator,
- MLEngineManageVersionOperator,
- MLEngineStartBatchPredictionJobOperator,
- MLEngineStartTrainingJobOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.operators.mlengine`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MLEngineBatchPredictionOperator(MLEngineStartBatchPredictionJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class MLEngineModelOperator(MLEngineManageModelOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageModelOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class MLEngineTrainingOperator(MLEngineStartTrainingJobOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class MLEngineVersionOperator(MLEngineManageVersionOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.operators.mlengine.MLEngineManageVersionOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/mongo_to_s3.py b/airflow/contrib/operators/mongo_to_s3.py
deleted file mode 100644
index 17b0676952e91..0000000000000
--- a/airflow/contrib/operators/mongo_to_s3.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.mongo_to_s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.mongo_to_s3 import MongoToS3Operator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.mongo_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/mssql_to_gcs.py b/airflow/contrib/operators/mssql_to_gcs.py
deleted file mode 100644
index 140457bb709f4..0000000000000
--- a/airflow/contrib/operators/mssql_to_gcs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.mssql_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.mssql_to_gcs import MSSQLToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MsSqlToGoogleCloudStorageOperator(MSSQLToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.mssql_to_gcs.MSSQLToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/mysql_to_gcs.py b/airflow/contrib/operators/mysql_to_gcs.py
deleted file mode 100644
index 50363b3a39bc6..0000000000000
--- a/airflow/contrib/operators/mysql_to_gcs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.mysql_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.mysql_to_gcs import MySQLToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MySqlToGoogleCloudStorageOperator(MySQLToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.mysql_to_gcs.MySQLToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/opsgenie_alert_operator.py b/airflow/contrib/operators/opsgenie_alert_operator.py
deleted file mode 100644
index 008214b97e020..0000000000000
--- a/airflow/contrib/operators/opsgenie_alert_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.opsgenie.operators.opsgenie`."""
-
-import warnings
-
-from airflow.providers.opsgenie.operators.opsgenie import OpsgenieCreateAlertOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.opsgenie.operators.opsgenie`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class OpsgenieAlertOperator(OpsgenieCreateAlertOperator):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use :class:`airflow.providers.opsgenie.operators.opsgenie.OpsgenieCreateAlertOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py b/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py
deleted file mode 100644
index 3907b6f8e8b6b..0000000000000
--- a/airflow/contrib/operators/oracle_to_azure_data_lake_transfer.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake`.
-"""
-
-import warnings
-
-from airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake import ( # noqa
- OracleToAzureDataLakeOperator,
-)
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.microsoft.azure.transfers.oracle_to_azure_data_lake`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/oracle_to_oracle_transfer.py b/airflow/contrib/operators/oracle_to_oracle_transfer.py
deleted file mode 100644
index 2efbf4d0a89b4..0000000000000
--- a/airflow/contrib/operators/oracle_to_oracle_transfer.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.oracle.transfers.oracle_to_oracle`.
-"""
-
-import warnings
-
-from airflow.providers.oracle.transfers.oracle_to_oracle import OracleToOracleOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.oracle.transfers.oracle_to_oracle`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class OracleToOracleTransfer(OracleToOracleOperator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.oracle.transfers.oracle_to_oracle.OracleToOracleOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/postgres_to_gcs_operator.py b/airflow/contrib/operators/postgres_to_gcs_operator.py
deleted file mode 100644
index 1ce5f42304ce5..0000000000000
--- a/airflow/contrib/operators/postgres_to_gcs_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.postgres_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.postgres_to_gcs import PostgresToGCSOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class PostgresToGoogleCloudStorageOperator(PostgresToGCSOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.transfers.postgres_to_gcs.PostgresToGCSOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/pubsub_operator.py b/airflow/contrib/operators/pubsub_operator.py
deleted file mode 100644
index 9d85d32fefed1..0000000000000
--- a/airflow/contrib/operators/pubsub_operator.py
+++ /dev/null
@@ -1,118 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.operators.pubsub`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.operators.pubsub import (
- PubSubCreateSubscriptionOperator,
- PubSubCreateTopicOperator,
- PubSubDeleteSubscriptionOperator,
- PubSubDeleteTopicOperator,
- PubSubPublishMessageOperator,
-)
-
-warnings.warn(
- """This module is deprecated.
- "Please use `airflow.providers.google.cloud.operators.pubsub`.""",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class PubSubPublishOperator(PubSubPublishMessageOperator):
- """This class is deprecated.
-
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubPublishMessageOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class PubSubSubscriptionCreateOperator(PubSubCreateSubscriptionOperator):
- """This class is deprecated.
-
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateSubscriptionOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class PubSubSubscriptionDeleteOperator(PubSubDeleteSubscriptionOperator):
- """This class is deprecated.
-
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteSubscriptionOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class PubSubTopicCreateOperator(PubSubCreateTopicOperator):
- """This class is deprecated.
-
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubCreateTopicOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class PubSubTopicDeleteOperator(PubSubDeleteTopicOperator):
- """This class is deprecated.
-
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.operators.pubsub.PubSubDeleteTopicOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/qubole_check_operator.py b/airflow/contrib/operators/qubole_check_operator.py
deleted file mode 100644
index e42a9e0c8f13a..0000000000000
--- a/airflow/contrib/operators/qubole_check_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.qubole.operators.qubole_check`."""
-
-import warnings
-
-from airflow.providers.qubole.operators.qubole_check import ( # noqa
- QuboleCheckOperator,
- QuboleValueCheckOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.qubole.operators.qubole_check`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/qubole_operator.py b/airflow/contrib/operators/qubole_operator.py
deleted file mode 100644
index e4a30748a03c6..0000000000000
--- a/airflow/contrib/operators/qubole_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.qubole.operators.qubole`."""
-
-import warnings
-
-from airflow.providers.qubole.operators.qubole import QuboleOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.qubole.operators.qubole`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/redis_publish_operator.py b/airflow/contrib/operators/redis_publish_operator.py
deleted file mode 100644
index 994d9323ad170..0000000000000
--- a/airflow/contrib/operators/redis_publish_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.redis.operators.redis_publish`."""
-
-import warnings
-
-from airflow.providers.redis.operators.redis_publish import RedisPublishOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.redis.operators.redis_publish`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/s3_copy_object_operator.py b/airflow/contrib/operators/s3_copy_object_operator.py
deleted file mode 100644
index cbe9c63440629..0000000000000
--- a/airflow/contrib/operators/s3_copy_object_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3_copy_object`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3_copy_object import S3CopyObjectOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_copy_object`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/s3_delete_objects_operator.py b/airflow/contrib/operators/s3_delete_objects_operator.py
deleted file mode 100644
index a0ab210324b55..0000000000000
--- a/airflow/contrib/operators/s3_delete_objects_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.operators.s3_delete_objects`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3_delete_objects import S3DeleteObjectsOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_delete_objects`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/s3_list_operator.py b/airflow/contrib/operators/s3_list_operator.py
deleted file mode 100644
index 172b94cb116c2..0000000000000
--- a/airflow/contrib/operators/s3_list_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3_list`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3_list import S3ListOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_list`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/s3_to_gcs_operator.py b/airflow/contrib/operators/s3_to_gcs_operator.py
deleted file mode 100644
index d0ea8e09bdbb1..0000000000000
--- a/airflow/contrib/operators/s3_to_gcs_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.s3_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.s3_to_gcs import S3ToGCSOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.s3_to_gcs`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py b/airflow/contrib/operators/s3_to_gcs_transfer_operator.py
deleted file mode 100644
index 71df06229334b..0000000000000
--- a/airflow/contrib/operators/s3_to_gcs_transfer_operator.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.
-"""
-import warnings
-
-from airflow.providers.google.cloud.operators.cloud_storage_transfer_service import ( # noqa isort:skip
- CloudDataTransferServiceS3ToGCSOperator,
-)
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.cloud.operators.cloud_storage_transfer_service`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/s3_to_sftp_operator.py b/airflow/contrib/operators/s3_to_sftp_operator.py
deleted file mode 100644
index e129af1471782..0000000000000
--- a/airflow/contrib/operators/s3_to_sftp_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.s3_to_sftp`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.s3_to_sftp import S3ToSFTPOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_sftp`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_base_operator.py b/airflow/contrib/operators/sagemaker_base_operator.py
deleted file mode 100644
index 4c2c8f6baf131..0000000000000
--- a/airflow/contrib/operators/sagemaker_base_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_base`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_base`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py b/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
deleted file mode 100644
index 43945b222734d..0000000000000
--- a/airflow/contrib/operators/sagemaker_endpoint_config_operator.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_endpoint_config import ( # noqa
- SageMakerEndpointConfigOperator,
-)
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint_config`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_endpoint_operator.py b/airflow/contrib/operators/sagemaker_endpoint_operator.py
deleted file mode 100644
index fe175a67b29cf..0000000000000
--- a/airflow/contrib/operators/sagemaker_endpoint_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_endpoint`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_endpoint import SageMakerEndpointOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_endpoint`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_model_operator.py b/airflow/contrib/operators/sagemaker_model_operator.py
deleted file mode 100644
index 9a003485606ab..0000000000000
--- a/airflow/contrib/operators/sagemaker_model_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_model`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_model import SageMakerModelOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_model`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_training_operator.py b/airflow/contrib/operators/sagemaker_training_operator.py
deleted file mode 100644
index d3749c68573d2..0000000000000
--- a/airflow/contrib/operators/sagemaker_training_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_training`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_training import SageMakerTrainingOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_training`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_transform_operator.py b/airflow/contrib/operators/sagemaker_transform_operator.py
deleted file mode 100644
index 93cf7070db220..0000000000000
--- a/airflow/contrib/operators/sagemaker_transform_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_transform`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_transform import SageMakerTransformOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_transform`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sagemaker_tuning_operator.py b/airflow/contrib/operators/sagemaker_tuning_operator.py
deleted file mode 100644
index 05760a74569cf..0000000000000
--- a/airflow/contrib/operators/sagemaker_tuning_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker_tuning`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker_tuning import SageMakerTuningOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker_tuning`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/segment_track_event_operator.py b/airflow/contrib/operators/segment_track_event_operator.py
deleted file mode 100644
index 92419a1977405..0000000000000
--- a/airflow/contrib/operators/segment_track_event_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.segment.operators.segment_track_event`."""
-
-import warnings
-
-from airflow.providers.segment.operators.segment_track_event import SegmentTrackEventOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.segment.operators.segment_track_event`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sftp_operator.py b/airflow/contrib/operators/sftp_operator.py
deleted file mode 100644
index e73a84743c83b..0000000000000
--- a/airflow/contrib/operators/sftp_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.sftp.operators.sftp`."""
-
-import warnings
-
-from airflow.providers.sftp.operators.sftp import SFTPOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.sftp.operators.sftp`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sftp_to_s3_operator.py b/airflow/contrib/operators/sftp_to_s3_operator.py
deleted file mode 100644
index 7c13b1817d63e..0000000000000
--- a/airflow/contrib/operators/sftp_to_s3_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.sftp_to_s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.sftp_to_s3 import SFTPToS3Operator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.sftp_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/operators/slack_webhook_operator.py b/airflow/contrib/operators/slack_webhook_operator.py
deleted file mode 100644
index f271102e14550..0000000000000
--- a/airflow/contrib/operators/slack_webhook_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.slack.operators.slack_webhook`."""
-
-import warnings
-
-from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.slack.operators.slack_webhook`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/snowflake_operator.py b/airflow/contrib/operators/snowflake_operator.py
deleted file mode 100644
index f01cc72d1a531..0000000000000
--- a/airflow/contrib/operators/snowflake_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.snowflake.operators.snowflake`."""
-
-import warnings
-
-from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.snowflake.operators.snowflake`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sns_publish_operator.py b/airflow/contrib/operators/sns_publish_operator.py
deleted file mode 100644
index 104e240836f12..0000000000000
--- a/airflow/contrib/operators/sns_publish_operator.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sns`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sns`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/spark_jdbc_operator.py b/airflow/contrib/operators/spark_jdbc_operator.py
deleted file mode 100644
index fc3cdc02704c6..0000000000000
--- a/airflow/contrib/operators/spark_jdbc_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.operators.spark_jdbc`."""
-
-import warnings
-
-from airflow.providers.apache.spark.operators.spark_jdbc import SparkJDBCOperator, SparkSubmitOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_jdbc`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/spark_sql_operator.py b/airflow/contrib/operators/spark_sql_operator.py
deleted file mode 100644
index 19e20d215dbf5..0000000000000
--- a/airflow/contrib/operators/spark_sql_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.operators.spark_sql`."""
-
-import warnings
-
-from airflow.providers.apache.spark.operators.spark_sql import SparkSqlOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_sql`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
deleted file mode 100644
index 103187e445fa4..0000000000000
--- a/airflow/contrib/operators/spark_submit_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.spark.operators.spark_submit`."""
-
-import warnings
-
-from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.apache.spark.operators.spark_submit`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/sql_to_gcs.py b/airflow/contrib/operators/sql_to_gcs.py
deleted file mode 100644
index 13aa869c5d98d..0000000000000
--- a/airflow/contrib/operators/sql_to_gcs.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.transfers.sql_to_gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.google.cloud.transfers.sql_to_gcs`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class BaseSQLToGoogleCloudStorageOperator(BaseSQLToGCSOperator):
-    """
-    This class is deprecated.
-    Please use `airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use `airflow.providers.google.cloud.transfers.sql_to_gcs.BaseSQLToGCSOperator`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/sqoop_operator.py b/airflow/contrib/operators/sqoop_operator.py
deleted file mode 100644
index 2757847abd65e..0000000000000
--- a/airflow/contrib/operators/sqoop_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.sqoop.operators.sqoop`."""
-
-import warnings
-
-from airflow.providers.apache.sqoop.operators.sqoop import SqoopOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.apache.sqoop.operators.sqoop`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/ssh_operator.py b/airflow/contrib/operators/ssh_operator.py
deleted file mode 100644
index 56f94b9b26b63..0000000000000
--- a/airflow/contrib/operators/ssh_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.ssh.operators.ssh`."""
-
-import warnings
-
-from airflow.providers.ssh.operators.ssh import SSHOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.ssh.operators.ssh`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/vertica_operator.py b/airflow/contrib/operators/vertica_operator.py
deleted file mode 100644
index e652512ad4056..0000000000000
--- a/airflow/contrib/operators/vertica_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.vertica.operators.vertica`."""
-
-import warnings
-
-from airflow.providers.vertica.operators.vertica import VerticaOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.vertica.operators.vertica`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/vertica_to_hive.py b/airflow/contrib/operators/vertica_to_hive.py
deleted file mode 100644
index 49ddea172f374..0000000000000
--- a/airflow/contrib/operators/vertica_to_hive.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.apache.hive.transfers.vertica_to_hive`.
-"""
-
-import warnings
-
-from airflow.providers.apache.hive.transfers.vertica_to_hive import VerticaToHiveOperator
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.vertica_to_hive`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class VerticaToHiveTransfer(VerticaToHiveOperator):
-    """This class is deprecated.
-
-    Please use:
-    `airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use
-            `airflow.providers.apache.hive.transfers.vertica_to_hive.VerticaToHiveOperator`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/vertica_to_mysql.py b/airflow/contrib/operators/vertica_to_mysql.py
deleted file mode 100644
index c85738f9091b2..0000000000000
--- a/airflow/contrib/operators/vertica_to_mysql.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.mysql.transfers.vertica_to_mysql`.
-"""
-
-import warnings
-
-from airflow.providers.mysql.transfers.vertica_to_mysql import VerticaToMySqlOperator
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.mysql.transfers.vertica_to_mysql`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class VerticaToMySqlTransfer(VerticaToMySqlOperator):
-    """This class is deprecated.
-
-    Please use:
-    `airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use
-            `airflow.providers.mysql.transfers.vertica_to_mysql.VerticaToMySqlOperator`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/operators/wasb_delete_blob_operator.py b/airflow/contrib/operators/wasb_delete_blob_operator.py
deleted file mode 100644
index cbf11b38fbf86..0000000000000
--- a/airflow/contrib/operators/wasb_delete_blob_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.microsoft.azure.operators.wasb_delete_blob`.
-"""
-
-import warnings
-
-from airflow.providers.microsoft.azure.operators.wasb_delete_blob import WasbDeleteBlobOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.microsoft.azure.operators.wasb_delete_blob`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/operators/winrm_operator.py b/airflow/contrib/operators/winrm_operator.py
deleted file mode 100644
index fcc6213e71d6e..0000000000000
--- a/airflow/contrib/operators/winrm_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.winrm.operators.winrm`."""
-
-import warnings
-
-from airflow.providers.microsoft.winrm.operators.winrm import WinRMOperator # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.microsoft.winrm.operators.winrm`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/secrets/__init__.py b/airflow/contrib/secrets/__init__.py
index 31cf12a8e38db..9e2b4c74abdd7 100644
--- a/airflow/contrib/secrets/__init__.py
+++ b/airflow/contrib/secrets/__init__.py
@@ -16,3 +16,41 @@
# specific language governing permissions and limitations
# under the License.
"""This package is deprecated. Please use `airflow.secrets` or `airflow.providers.*.secrets`."""
+from __future__ import annotations
+
+import warnings
+
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+warnings.warn(
+    "This module is deprecated. Please use airflow.providers.*.secrets.",
+    RemovedInAirflow3Warning,
+    stacklevel=2
+)
+
+__deprecated_classes = {
+    'aws_secrets_manager': {
+        'SecretsManagerBackend': 'airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend',
+    },
+    'aws_systems_manager': {
+        'SystemsManagerParameterStoreBackend': (
+            'airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend'
+        ),
+    },
+    'azure_key_vault': {
+        'AzureKeyVaultBackend': 'airflow.providers.microsoft.azure.secrets.key_vault.AzureKeyVaultBackend',
+    },
+    'gcp_secrets_manager': {
+        'CloudSecretManagerBackend': (
+            'airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend'
+        ),
+        'CloudSecretsManagerBackend': (
+            'airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend'
+        ),
+    },
+    'hashicorp_vault': {
+        'VaultBackend': 'airflow.providers.hashicorp.secrets.vault.VaultBackend',
+    },
+}
+add_deprecated_classes(__deprecated_classes, __name__)
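Review note: the hunk above swaps one stub file per legacy module for a single mapping passed to ``add_deprecated_classes``. The diff shows only the call, not the helper, so the following is a minimal sketch of how such a lazy redirect could work, assuming a PEP 562 module-level ``__getattr__``. The name ``register_deprecated_module`` and the use of ``DeprecationWarning`` are illustrative assumptions, not Airflow's actual implementation.

```python
import importlib
import sys
import types
import warnings


def register_deprecated_module(old_module_name, redirects, category=DeprecationWarning):
    """Install a stand-in module that resolves old names lazily.

    ``redirects`` maps an old attribute name to the dotted path of its
    replacement, mirroring one inner dict of ``__deprecated_classes``.
    (Hypothetical helper, written for illustration only.)
    """
    shim = types.ModuleType(old_module_name)

    def _lazy_getattr(name):
        new_path = redirects.get(name)
        if new_path is None:
            # Unknown names (including dunder probes made by the import
            # system) fail the usual way and do not emit a warning.
            raise AttributeError(f"module {old_module_name!r} has no attribute {name!r}")
        warnings.warn(
            f"{old_module_name}.{name} is deprecated. Please use {new_path}.",
            category,
            stacklevel=2,
        )
        new_module_name, _, attribute = new_path.rpartition(".")
        # The replacement module is imported only now, on first access.
        return getattr(importlib.import_module(new_module_name), attribute)

    # PEP 562: a ``__getattr__`` stored in the module namespace is called for
    # attributes that normal lookup does not find.
    shim.__getattr__ = _lazy_getattr
    sys.modules[old_module_name] = shim
    return shim
```

With a stdlib target standing in for a provider class, ``register_deprecated_module("legacy_shim_example", {"OldOrderedDict": "collections.OrderedDict"})`` returns a module whose ``OldOrderedDict`` attribute resolves to ``collections.OrderedDict`` and warns on access. The point of the design is that the replacement package is not imported until a deprecated name is actually used.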
diff --git a/airflow/contrib/secrets/aws_secrets_manager.py b/airflow/contrib/secrets/aws_secrets_manager.py
deleted file mode 100644
index 833b03a59b06c..0000000000000
--- a/airflow/contrib/secrets/aws_secrets_manager.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.secrets.secrets_manager`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.secrets.secrets_manager import SecretsManagerBackend # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.secrets.secrets_manager`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/secrets/aws_systems_manager.py b/airflow/contrib/secrets/aws_systems_manager.py
deleted file mode 100644
index 4c7a30cf05ab7..0000000000000
--- a/airflow/contrib/secrets/aws_systems_manager.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.secrets.systems_manager`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.secrets.systems_manager import SystemsManagerParameterStoreBackend # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.secrets.systems_manager`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/secrets/azure_key_vault.py b/airflow/contrib/secrets/azure_key_vault.py
deleted file mode 100644
index 000ae92b3ac28..0000000000000
--- a/airflow/contrib/secrets/azure_key_vault.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.secrets.key_vault`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.secrets.key_vault import AzureKeyVaultBackend # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.microsoft.azure.secrets.key_vault`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/secrets/gcp_secrets_manager.py b/airflow/contrib/secrets/gcp_secrets_manager.py
deleted file mode 100644
index 7caa7ea2e88a7..0000000000000
--- a/airflow/contrib/secrets/gcp_secrets_manager.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.secrets.secret_manager`."""
-
-import warnings
-
-from airflow.providers.google.cloud.secrets.secret_manager import CloudSecretManagerBackend
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.google.cloud.secrets.secret_manager`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class CloudSecretsManagerBackend(CloudSecretManagerBackend):
-    """
-    This class is deprecated.
-    Please use `airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use `airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/secrets/hashicorp_vault.py b/airflow/contrib/secrets/hashicorp_vault.py
deleted file mode 100644
index a3158d5d03abf..0000000000000
--- a/airflow/contrib/secrets/hashicorp_vault.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.hashicorp.secrets.vault`."""
-
-import warnings
-
-from airflow.providers.hashicorp.secrets.vault import VaultBackend # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.hashicorp.secrets.vault`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/__init__.py b/airflow/contrib/sensors/__init__.py
index de90d6c210596..65739f1ed0eb5 100644
--- a/airflow/contrib/sensors/__init__.py
+++ b/airflow/contrib/sensors/__init__.py
@@ -16,11 +16,151 @@
# specific language governing permissions and limitations
# under the License.
"""This package is deprecated. Please use `airflow.sensors` or `airflow.providers.*.sensors`."""
+from __future__ import annotations
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
 warnings.warn(
     "This package is deprecated. Please use `airflow.sensors` or `airflow.providers.*.sensors`.",
-    DeprecationWarning,
+    RemovedInAirflow3Warning,
     stacklevel=2,
 )
+
+__deprecated_classes = {
+    'aws_athena_sensor': {
+        'AthenaSensor': 'airflow.providers.amazon.aws.sensors.athena.AthenaSensor',
+    },
+    'aws_glue_catalog_partition_sensor': {
+        'AwsGlueCatalogPartitionSensor': (
+            'airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor'
+        ),
+    },
+    'aws_redshift_cluster_sensor': {
+        'AwsRedshiftClusterSensor': (
+            'airflow.providers.amazon.aws.sensors.redshift_cluster.RedshiftClusterSensor'
+        ),
+    },
+    'aws_sqs_sensor': {
+        'SqsSensor': 'airflow.providers.amazon.aws.sensors.sqs.SqsSensor',
+        'SQSSensor': 'airflow.providers.amazon.aws.sensors.sqs.SqsSensor',
+    },
+    'azure_cosmos_sensor': {
+        'AzureCosmosDocumentSensor': (
+            'airflow.providers.microsoft.azure.sensors.cosmos.AzureCosmosDocumentSensor'
+        ),
+    },
+    'bash_sensor': {
+        'STDOUT': 'airflow.sensors.bash.STDOUT',
+        'BashSensor': 'airflow.sensors.bash.BashSensor',
+        'Popen': 'airflow.sensors.bash.Popen',
+        'TemporaryDirectory': 'airflow.sensors.bash.TemporaryDirectory',
+        'gettempdir': 'airflow.sensors.bash.gettempdir',
+    },
+    'bigquery_sensor': {
+        'BigQueryTableExistenceSensor': (
+            'airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor'
+        ),
+        'BigQueryTableSensor': 'airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor',
+    },
+    'cassandra_record_sensor': {
+        'CassandraRecordSensor': 'airflow.providers.apache.cassandra.sensors.record.CassandraRecordSensor',
+    },
+    'cassandra_table_sensor': {
+        'CassandraTableSensor': 'airflow.providers.apache.cassandra.sensors.table.CassandraTableSensor',
+    },
+    'celery_queue_sensor': {
+        'CeleryQueueSensor': 'airflow.providers.celery.sensors.celery_queue.CeleryQueueSensor',
+    },
+    'datadog_sensor': {
+        'DatadogSensor': 'airflow.providers.datadog.sensors.datadog.DatadogSensor',
+    },
+    'file_sensor': {
+        'FileSensor': 'airflow.sensors.filesystem.FileSensor',
+    },
+    'ftp_sensor': {
+        'FTPSensor': 'airflow.providers.ftp.sensors.ftp.FTPSensor',
+        'FTPSSensor': 'airflow.providers.ftp.sensors.ftp.FTPSSensor',
+    },
+    'gcp_transfer_sensor': {
+        'CloudDataTransferServiceJobStatusSensor':
+            'airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.'
+            'CloudDataTransferServiceJobStatusSensor',
+        'GCPTransferServiceWaitForJobStatusSensor':
+            'airflow.providers.google.cloud.sensors.cloud_storage_transfer_service.'
+            'CloudDataTransferServiceJobStatusSensor',
+    },
+    'gcs_sensor': {
+        'GCSObjectExistenceSensor': 'airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor',
+        'GCSObjectsWithPrefixExistenceSensor': (
+            'airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor'
+        ),
+        'GCSObjectUpdateSensor': 'airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor',
+        'GCSUploadSessionCompleteSensor': (
+            'airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor'
+        ),
+        'GoogleCloudStorageObjectSensor': (
+            'airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor'
+        ),
+        'GoogleCloudStorageObjectUpdatedSensor': (
+            'airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor'
+        ),
+        'GoogleCloudStoragePrefixSensor': (
+            'airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor'
+        ),
+        'GoogleCloudStorageUploadSessionCompleteSensor': (
+            'airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor'
+        ),
+    },
+    'hdfs_sensor': {
+        'HdfsFolderSensor': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor',
+ 'HdfsRegexSensor': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor',
+ 'HdfsSensorFolder': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor',
+ 'HdfsSensorRegex': 'airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor',
+ },
+ 'imap_attachment_sensor': {
+ 'ImapAttachmentSensor': 'airflow.providers.imap.sensors.imap_attachment.ImapAttachmentSensor',
+ },
+ 'jira_sensor': {
+ 'JiraSensor': 'airflow.providers.atlassian.jira.sensors.jira.JiraSensor',
+ 'JiraTicketSensor': 'airflow.providers.atlassian.jira.sensors.jira.JiraTicketSensor',
+ },
+ 'mongo_sensor': {
+ 'MongoSensor': 'airflow.providers.mongo.sensors.mongo.MongoSensor',
+ },
+ 'pubsub_sensor': {
+ 'PubSubPullSensor': 'airflow.providers.google.cloud.sensors.pubsub.PubSubPullSensor',
+ },
+ 'python_sensor': {
+ 'PythonSensor': 'airflow.sensors.python.PythonSensor',
+ },
+ 'qubole_sensor': {
+ 'QuboleFileSensor': 'airflow.providers.qubole.sensors.qubole.QuboleFileSensor',
+ 'QubolePartitionSensor': 'airflow.providers.qubole.sensors.qubole.QubolePartitionSensor',
+ 'QuboleSensor': 'airflow.providers.qubole.sensors.qubole.QuboleSensor',
+ },
+ 'redis_key_sensor': {
+ 'RedisKeySensor': 'airflow.providers.redis.sensors.redis_key.RedisKeySensor',
+ },
+ 'redis_pub_sub_sensor': {
+ 'RedisPubSubSensor': 'airflow.providers.redis.sensors.redis_pub_sub.RedisPubSubSensor',
+ },
+ 'sagemaker_training_sensor': {
+ 'SageMakerHook': 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerHook',
+ 'SageMakerTrainingSensor': 'airflow.providers.amazon.aws.sensors.sagemaker.SageMakerTrainingSensor',
+ },
+ 'sftp_sensor': {
+ 'SFTPSensor': 'airflow.providers.sftp.sensors.sftp.SFTPSensor',
+ },
+ 'wasb_sensor': {
+ 'WasbBlobSensor': 'airflow.providers.microsoft.azure.sensors.wasb.WasbBlobSensor',
+ 'WasbPrefixSensor': 'airflow.providers.microsoft.azure.sensors.wasb.WasbPrefixSensor',
+ },
+ 'weekday_sensor': {
+ 'DayOfWeekSensor': 'airflow.sensors.weekday.DayOfWeekSensor',
+ },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
diff --git a/airflow/contrib/sensors/aws_athena_sensor.py b/airflow/contrib/sensors/aws_athena_sensor.py
deleted file mode 100644
index ddffc38bb63e8..0000000000000
--- a/airflow/contrib/sensors/aws_athena_sensor.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.athena`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.athena import AthenaSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.athena`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py b/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
deleted file mode 100644
index 66975a86f3e2f..0000000000000
--- a/airflow/contrib/sensors/aws_glue_catalog_partition_sensor.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.sensors.glue_catalog_partition`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.glue_catalog_partition import AwsGlueCatalogPartitionSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.glue_catalog_partition`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py
deleted file mode 100644
index dc0da1372b400..0000000000000
--- a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.redshift_cluster`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.redshift_cluster import (
- RedshiftClusterSensor as AwsRedshiftClusterSensor,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-__all__ = ["AwsRedshiftClusterSensor"]
diff --git a/airflow/contrib/sensors/aws_sqs_sensor.py b/airflow/contrib/sensors/aws_sqs_sensor.py
deleted file mode 100644
index ed251f1ee9c22..0000000000000
--- a/airflow/contrib/sensors/aws_sqs_sensor.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sqs`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sqs`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class SQSSensor(SqsSensor):
- """
- This sensor is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.sqs.SqsSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.sensors.sqs.SqsSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/azure_cosmos_sensor.py b/airflow/contrib/sensors/azure_cosmos_sensor.py
deleted file mode 100644
index fc3df4f26615e..0000000000000
--- a/airflow/contrib/sensors/azure_cosmos_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.cosmos`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.sensors.cosmos import AzureCosmosDocumentSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.cosmos`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/bash_sensor.py b/airflow/contrib/sensors/bash_sensor.py
deleted file mode 100644
index c3d9c814696c7..0000000000000
--- a/airflow/contrib/sensors/bash_sensor.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.sensors.bash`."""
-
-import warnings
-
-from airflow.sensors.bash import STDOUT, BashSensor, Popen, TemporaryDirectory, gettempdir # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.sensors.bash`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/contrib/sensors/bigquery_sensor.py b/airflow/contrib/sensors/bigquery_sensor.py
deleted file mode 100644
index d58445e3d174b..0000000000000
--- a/airflow/contrib/sensors/bigquery_sensor.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.sensors.bigquery`."""
-
-import warnings
-
-from airflow.providers.google.cloud.sensors.bigquery import BigQueryTableExistenceSensor
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.sensors.bigquery`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class BigQueryTableSensor(BigQueryTableExistenceSensor):
- """
- This class is deprecated.
- Please use `airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.google.cloud.sensors.bigquery.BigQueryTableExistenceSensor`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/cassandra_record_sensor.py b/airflow/contrib/sensors/cassandra_record_sensor.py
deleted file mode 100644
index cfc3b30107ead..0000000000000
--- a/airflow/contrib/sensors/cassandra_record_sensor.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.cassandra.sensors.record`."""
-
-import warnings
-
-from airflow.providers.apache.cassandra.sensors.record import CassandraRecordSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.cassandra.sensors.record`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/cassandra_table_sensor.py b/airflow/contrib/sensors/cassandra_table_sensor.py
deleted file mode 100644
index 0b7c7aa6eb73f..0000000000000
--- a/airflow/contrib/sensors/cassandra_table_sensor.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.cassandra.sensors.table`."""
-
-import warnings
-
-from airflow.providers.apache.cassandra.sensors.table import CassandraTableSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.cassandra.sensors.table`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/celery_queue_sensor.py b/airflow/contrib/sensors/celery_queue_sensor.py
deleted file mode 100644
index 6ed2be1c93ac5..0000000000000
--- a/airflow/contrib/sensors/celery_queue_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.celery.sensors.celery_queue`."""
-
-import warnings
-
-from airflow.providers.celery.sensors.celery_queue import CeleryQueueSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.celery.sensors.celery_queue`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/datadog_sensor.py b/airflow/contrib/sensors/datadog_sensor.py
deleted file mode 100644
index d0377d11a8889..0000000000000
--- a/airflow/contrib/sensors/datadog_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.datadog.sensors.datadog`."""
-
-import warnings
-
-from airflow.providers.datadog.sensors.datadog import DatadogSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.datadog.sensors.datadog`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/emr_base_sensor.py b/airflow/contrib/sensors/emr_base_sensor.py
deleted file mode 100644
index 08d0efed81475..0000000000000
--- a/airflow/contrib/sensors/emr_base_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr_base`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr_base import EmrBaseSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_base`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/emr_job_flow_sensor.py b/airflow/contrib/sensors/emr_job_flow_sensor.py
deleted file mode 100644
index 429052a4ec969..0000000000000
--- a/airflow/contrib/sensors/emr_job_flow_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr_job_flow`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr_job_flow import EmrJobFlowSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_job_flow`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/emr_step_sensor.py b/airflow/contrib/sensors/emr_step_sensor.py
deleted file mode 100644
index 9d4ac9b166ed6..0000000000000
--- a/airflow/contrib/sensors/emr_step_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr_step`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr_step import EmrStepSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr_step`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/file_sensor.py b/airflow/contrib/sensors/file_sensor.py
deleted file mode 100644
index 6d75b657e6cef..0000000000000
--- a/airflow/contrib/sensors/file_sensor.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.sensors.filesystem`."""
-
-import warnings
-
-from airflow.sensors.filesystem import FileSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.sensors.filesystem`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/contrib/sensors/ftp_sensor.py b/airflow/contrib/sensors/ftp_sensor.py
deleted file mode 100644
index 76c47c4609854..0000000000000
--- a/airflow/contrib/sensors/ftp_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.ftp.sensors.ftp`."""
-
-import warnings
-
-from airflow.providers.ftp.sensors.ftp import FTPSensor, FTPSSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.ftp.sensors.ftp`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/gcp_transfer_sensor.py b/airflow/contrib/sensors/gcp_transfer_sensor.py
deleted file mode 100644
index 429ddb1a44b05..0000000000000
--- a/airflow/contrib/sensors/gcp_transfer_sensor.py
+++ /dev/null
@@ -1,51 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.sensors.cloud_storage_transfer_service`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.sensors.cloud_storage_transfer_service import (
- CloudDataTransferServiceJobStatusSensor,
-)
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.cloud.sensors.cloud_storage_transfer_service`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GCPTransferServiceWaitForJobStatusSensor(CloudDataTransferServiceJobStatusSensor):
- """This class is deprecated.
-
- Please use `airflow.providers.google.cloud.sensors.transfer.CloudDataTransferServiceJobStatusSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.google.cloud.sensors.transfer.CloudDataTransferServiceJobStatusSensor`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/gcs_sensor.py b/airflow/contrib/sensors/gcs_sensor.py
deleted file mode 100644
index 3df15f676c265..0000000000000
--- a/airflow/contrib/sensors/gcs_sensor.py
+++ /dev/null
@@ -1,97 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.sensors.gcs`."""
-
-import warnings
-
-from airflow.providers.google.cloud.sensors.gcs import (
-    GCSObjectExistenceSensor,
-    GCSObjectsWithPrefixExistenceSensor,
-    GCSObjectUpdateSensor,
-    GCSUploadSessionCompleteSensor,
-)
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.google.cloud.sensors.gcs`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class GoogleCloudStorageObjectSensor(GCSObjectExistenceSensor):
-    """
-    This class is deprecated.
-    Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectExistenceSensor`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
-
-
-class GoogleCloudStorageObjectUpdatedSensor(GCSObjectUpdateSensor):
-    """
-    This class is deprecated.
-    Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectUpdateSensor`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
-
-
-class GoogleCloudStoragePrefixSensor(GCSObjectsWithPrefixExistenceSensor):
-    """
-    This class is deprecated.
-    Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use `airflow.providers.google.cloud.sensors.gcs.GCSObjectsWithPrefixExistenceSensor`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
-
-
-class GoogleCloudStorageUploadSessionCompleteSensor(GCSUploadSessionCompleteSensor):
-    """
-    This class is deprecated.
-    Please use `airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use `airflow.providers.google.cloud.sensors.gcs.GCSUploadSessionCompleteSensor`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/hdfs_sensor.py b/airflow/contrib/sensors/hdfs_sensor.py
deleted file mode 100644
index d71ec8fc2f454..0000000000000
--- a/airflow/contrib/sensors/hdfs_sensor.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.apache.hdfs.sensors.hdfs`.
-"""
-
-import warnings
-
-from airflow.providers.apache.hdfs.sensors.hdfs import HdfsFolderSensor, HdfsRegexSensor
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.apache.hdfs.sensors.hdfs`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
-
-
-class HdfsSensorFolder(HdfsFolderSensor):
-    """This class is deprecated.
-
-    Please use:
-    `airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use
-            `airflow.providers.apache.hdfs.sensors.hdfs.HdfsFolderSensor`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
-
-
-class HdfsSensorRegex(HdfsRegexSensor):
-    """This class is deprecated.
-
-    Please use:
-    `airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        warnings.warn(
-            """This class is deprecated.
-            Please use
-            `airflow.providers.apache.hdfs.sensors.hdfs.HdfsRegexSensor`.""",
-            DeprecationWarning,
-            stacklevel=2,
-        )
-        super().__init__(*args, **kwargs)
diff --git a/airflow/contrib/sensors/imap_attachment_sensor.py b/airflow/contrib/sensors/imap_attachment_sensor.py
deleted file mode 100644
index 34d2d7f1402e8..0000000000000
--- a/airflow/contrib/sensors/imap_attachment_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.imap.sensors.imap_attachment`."""
-
-import warnings
-
-from airflow.providers.imap.sensors.imap_attachment import ImapAttachmentSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.imap.sensors.imap_attachment`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/jira_sensor.py b/airflow/contrib/sensors/jira_sensor.py
deleted file mode 100644
index e7c3785209616..0000000000000
--- a/airflow/contrib/sensors/jira_sensor.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.jira.sensors.jira`."""
-
-import warnings
-
-from airflow.providers.jira.sensors.jira import JiraSensor, JiraTicketSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.jira.sensors.jira`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/mongo_sensor.py b/airflow/contrib/sensors/mongo_sensor.py
deleted file mode 100644
index 13a5f0b65af4c..0000000000000
--- a/airflow/contrib/sensors/mongo_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.mongo.sensors.mongo`."""
-
-import warnings
-
-from airflow.providers.mongo.sensors.mongo import MongoSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.mongo.sensors.mongo`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/pubsub_sensor.py b/airflow/contrib/sensors/pubsub_sensor.py
deleted file mode 100644
index eea404e216995..0000000000000
--- a/airflow/contrib/sensors/pubsub_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.sensors.pubsub`."""
-
-import warnings
-
-from airflow.providers.google.cloud.sensors.pubsub import PubSubPullSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.google.cloud.sensors.pubsub`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/python_sensor.py b/airflow/contrib/sensors/python_sensor.py
deleted file mode 100644
index bc7543c2fd372..0000000000000
--- a/airflow/contrib/sensors/python_sensor.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.sensors.python`."""
-
-import warnings
-
-from airflow.sensors.python import PythonSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.sensors.python`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/contrib/sensors/qubole_sensor.py b/airflow/contrib/sensors/qubole_sensor.py
deleted file mode 100644
index 6b656249003e9..0000000000000
--- a/airflow/contrib/sensors/qubole_sensor.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.qubole.sensors.qubole`."""
-
-import warnings
-
-from airflow.providers.qubole.sensors.qubole import (  # noqa
-    QuboleFileSensor,
-    QubolePartitionSensor,
-    QuboleSensor,
-)
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.qubole.sensors.qubole`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/redis_key_sensor.py b/airflow/contrib/sensors/redis_key_sensor.py
deleted file mode 100644
index f500c86dac5f3..0000000000000
--- a/airflow/contrib/sensors/redis_key_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.redis.sensors.redis_key`."""
-
-import warnings
-
-from airflow.providers.redis.sensors.redis_key import RedisKeySensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.redis.sensors.redis_key`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/redis_pub_sub_sensor.py b/airflow/contrib/sensors/redis_pub_sub_sensor.py
deleted file mode 100644
index 16946ac40d198..0000000000000
--- a/airflow/contrib/sensors/redis_pub_sub_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.redis.sensors.redis_pub_sub`."""
-
-import warnings
-
-from airflow.providers.redis.sensors.redis_pub_sub import RedisPubSubSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.redis.sensors.redis_pub_sub`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/sagemaker_base_sensor.py b/airflow/contrib/sensors/sagemaker_base_sensor.py
deleted file mode 100644
index 86e32330278e1..0000000000000
--- a/airflow/contrib/sensors/sagemaker_base_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_base`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker_base import SageMakerBaseSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_base`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py b/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
deleted file mode 100644
index 5107d6f542fd0..0000000000000
--- a/airflow/contrib/sensors/sagemaker_endpoint_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_endpoint`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker_endpoint import SageMakerEndpointSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_endpoint`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/sagemaker_training_sensor.py b/airflow/contrib/sensors/sagemaker_training_sensor.py
deleted file mode 100644
index 6a56516042a6f..0000000000000
--- a/airflow/contrib/sensors/sagemaker_training_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerHook, SageMakerTrainingSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/sagemaker_transform_sensor.py b/airflow/contrib/sensors/sagemaker_transform_sensor.py
deleted file mode 100644
index 29fd18f8baaae..0000000000000
--- a/airflow/contrib/sensors/sagemaker_transform_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_transform`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker_transform import SageMakerTransformSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_transform`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/sagemaker_tuning_sensor.py b/airflow/contrib/sensors/sagemaker_tuning_sensor.py
deleted file mode 100644
index 7079e4ccb774f..0000000000000
--- a/airflow/contrib/sensors/sagemaker_tuning_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker_tuning`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker_tuning import SageMakerTuningSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker_tuning`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/sftp_sensor.py b/airflow/contrib/sensors/sftp_sensor.py
deleted file mode 100644
index d2700e814295a..0000000000000
--- a/airflow/contrib/sensors/sftp_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.sftp.sensors.sftp`."""
-
-import warnings
-
-from airflow.providers.sftp.sensors.sftp import SFTPSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.sftp.sensors.sftp`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/wasb_sensor.py b/airflow/contrib/sensors/wasb_sensor.py
deleted file mode 100644
index d8e0748907afe..0000000000000
--- a/airflow/contrib/sensors/wasb_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.azure.sensors.wasb`."""
-
-import warnings
-
-from airflow.providers.microsoft.azure.sensors.wasb import WasbBlobSensor, WasbPrefixSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.microsoft.azure.sensors.wasb`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/sensors/weekday_sensor.py b/airflow/contrib/sensors/weekday_sensor.py
deleted file mode 100644
index 1f836e1dda936..0000000000000
--- a/airflow/contrib/sensors/weekday_sensor.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.sensors.weekday`."""
-
-import warnings
-
-from airflow.sensors.weekday import DayOfWeekSensor  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.sensors.weekday`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/task_runner/__init__.py b/airflow/contrib/task_runner/__init__.py
index 77842e3dc5b4c..82bd6bf04148b 100644
--- a/airflow/contrib/task_runner/__init__.py
+++ b/airflow/contrib/task_runner/__init__.py
@@ -16,3 +16,21 @@
# specific language governing permissions and limitations
# under the License.
"""This package is deprecated. Please use `airflow.task.task_runner`."""
+from __future__ import annotations
+
+import warnings
+
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+warnings.warn(
+    "This module is deprecated. Please use airflow.task.task_runner.", RemovedInAirflow3Warning, stacklevel=2
+)
+
+__deprecated_classes = {
+    'cgroup_task_runner': {
+        'CgroupTaskRunner': 'airflow.task.task_runner.cgroup_task_runner.CgroupTaskRunner',
+    },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
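
The hunk above replaces a per-module shim file with a single ``__deprecated_classes`` mapping passed to ``add_deprecated_classes``. The diff does not show that helper's body, so the following is only a minimal sketch of how such a helper could work, assuming it registers virtual modules whose attribute lookups are resolved lazily through a module-level ``__getattr__`` (PEP 562). The names ``add_deprecated_classes_sketch``, ``_resolve_with_warning``, ``legacy_pkg`` and ``old_mod`` are hypothetical, and the sketch uses the stdlib ``DeprecationWarning`` rather than ``RemovedInAirflow3Warning``.

```python
from __future__ import annotations

import functools
import importlib
import sys
import warnings
from types import ModuleType


def _resolve_with_warning(imports: dict[str, str], module: str, name: str):
    """Look up ``name`` in its new location and warn about the old one."""
    target = imports.get(name)
    if target is None:
        # Unknown names (including dunder probes such as ``__path__``)
        # must raise AttributeError so normal attribute semantics hold.
        raise AttributeError(f"module {module!r} has no attribute {name!r}")
    warnings.warn(
        f"`{module}.{name}` is deprecated. Please use `{target}`.",
        DeprecationWarning,
        stacklevel=2,
    )
    new_module, new_name = target.rsplit(".", 1)
    return getattr(importlib.import_module(new_module), new_name)


def add_deprecated_classes_sketch(module_imports: dict[str, dict[str, str]], package: str) -> None:
    """Register one virtual module per key of ``module_imports`` under ``package``."""
    for module_name, imports in module_imports.items():
        full_name = f"{package}.{module_name}"
        virtual = ModuleType(full_name)
        # A ``__getattr__`` stored in the module namespace is consulted only
        # when normal attribute lookup fails, so each old name resolves lazily.
        virtual.__getattr__ = functools.partial(_resolve_with_warning, imports, full_name)
        sys.modules.setdefault(full_name, virtual)


# Hypothetical usage: expose a stdlib class under a made-up legacy path.
add_deprecated_classes_sketch({"old_mod": {"OrderedDict": "collections.OrderedDict"}}, "legacy_pkg")
legacy_cls = sys.modules["legacy_pkg.old_mod"].OrderedDict  # emits a DeprecationWarning
```

Under this assumption, the new location is imported only when an old name is actually accessed, which is why the individual shim files can be deleted without eagerly importing every provider.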
diff --git a/airflow/contrib/task_runner/cgroup_task_runner.py b/airflow/contrib/task_runner/cgroup_task_runner.py
deleted file mode 100644
index f923126fe4475..0000000000000
--- a/airflow/contrib/task_runner/cgroup_task_runner.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.task.task_runner.cgroup_task_runner`."""
-
-import warnings
-
-from airflow.task.task_runner.cgroup_task_runner import CgroupTaskRunner  # noqa
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.task.task_runner.cgroup_task_runner`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/utils/__init__.py b/airflow/contrib/utils/__init__.py
index 4e6cbb7d51e23..4057f36868d74 100644
--- a/airflow/contrib/utils/__init__.py
+++ b/airflow/contrib/utils/__init__.py
@@ -16,7 +16,50 @@
# specific language governing permissions and limitations
# under the License.
"""This package is deprecated. Please use `airflow.utils`."""
+from __future__ import annotations
import warnings
-warnings.warn("This module is deprecated. Please use `airflow.utils`.", DeprecationWarning, stacklevel=2)
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+warnings.warn(
+    "This module is deprecated. Please use `airflow.utils`.",
+    RemovedInAirflow3Warning,
+    stacklevel=2
+)
+
+__deprecated_classes = {
+    'gcp_field_sanitizer': {
+        'GcpBodyFieldSanitizer': 'airflow.providers.google.cloud.utils.field_sanitizer.GcpBodyFieldSanitizer',
+        'GcpFieldSanitizerException': (
+            'airflow.providers.google.cloud.utils.field_sanitizer.GcpFieldSanitizerException'
+        ),
+    },
+    'gcp_field_validator': {
+        'GcpBodyFieldValidator': 'airflow.providers.google.cloud.utils.field_validator.GcpBodyFieldValidator',
+        'GcpFieldValidationException': (
+            'airflow.providers.google.cloud.utils.field_validator.GcpFieldValidationException'
+        ),
+        'GcpValidationSpecificationException': (
+            'airflow.providers.google.cloud.utils.field_validator.GcpValidationSpecificationException'
+        ),
+    },
+    'mlengine_operator_utils': {
+        'create_evaluate_ops': (
+            'airflow.providers.google.cloud.utils.mlengine_operator_utils.create_evaluate_ops'
+        ),
+    },
+    'mlengine_prediction_summary': {
+        'JsonCoder': 'airflow.providers.google.cloud.utils.mlengine_prediction_summary.JsonCoder',
+        'MakeSummary': 'airflow.providers.google.cloud.utils.mlengine_prediction_summary.MakeSummary',
+    },
+    'sendgrid': {
+        'import_string': 'airflow.utils.module_loading.import_string',
+    },
+    'weekday': {
+        'WeekDay': 'airflow.utils.weekday.WeekDay',
+    },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
diff --git a/airflow/contrib/utils/gcp_field_sanitizer.py b/airflow/contrib/utils/gcp_field_sanitizer.py
deleted file mode 100644
index 37c0aff2b76d6..0000000000000
--- a/airflow/contrib/utils/gcp_field_sanitizer.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.utils.field_sanitizer`"""
-
-import warnings
-
-from airflow.providers.google.cloud.utils.field_sanitizer import (  # noqa
-    GcpBodyFieldSanitizer,
-    GcpFieldSanitizerException,
-)
-
-warnings.warn(
-    "This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_sanitizer`.",
-    DeprecationWarning,
-    stacklevel=2,
-)
diff --git a/airflow/contrib/utils/gcp_field_validator.py b/airflow/contrib/utils/gcp_field_validator.py
deleted file mode 100644
index fc42dca94be00..0000000000000
--- a/airflow/contrib/utils/gcp_field_validator.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.google.cloud.utils.field_validator`."""
-
-import warnings
-
-from airflow.providers.google.cloud.utils.field_validator import ( # noqa
- GcpBodyFieldValidator,
- GcpFieldValidationException,
- GcpValidationSpecificationException,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.utils.field_validator`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/utils/log/__init__.py b/airflow/contrib/utils/log/__init__.py
index ecc47d3d85122..baf3a481858c8 100644
--- a/airflow/contrib/utils/log/__init__.py
+++ b/airflow/contrib/utils/log/__init__.py
@@ -15,7 +15,20 @@
# specific language governing permissions and limitations
# under the License.
"""This package is deprecated. Please use `airflow.utils.log`."""
+from __future__ import annotations
import warnings
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
warnings.warn("This module is deprecated. Please use `airflow.utils.log`.", DeprecationWarning, stacklevel=2)
+
+__deprecated_classes = {
+ 'task_handler_with_custom_formatter': {
+ 'TaskHandlerWithCustomFormatter': (
+ 'airflow.utils.log.task_handler_with_custom_formatter.TaskHandlerWithCustomFormatter'
+ ),
+ },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
diff --git a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py b/airflow/contrib/utils/log/task_handler_with_custom_formatter.py
deleted file mode 100644
index 9bbdee3b2c8c3..0000000000000
--- a/airflow/contrib/utils/log/task_handler_with_custom_formatter.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.utils.log.task_handler_with_custom_formatter`."""
-
-import warnings
-
-from airflow.utils.log.task_handler_with_custom_formatter import TaskHandlerWithCustomFormatter # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.utils.log.task_handler_with_custom_formatter`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/utils/mlengine_operator_utils.py b/airflow/contrib/utils/mlengine_operator_utils.py
deleted file mode 100644
index ebd630c1912a2..0000000000000
--- a/airflow/contrib/utils/mlengine_operator_utils.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.utils.mlengine_operator_utils`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.utils.mlengine_operator_utils import create_evaluate_ops # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.google.cloud.utils.mlengine_operator_utils`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/utils/mlengine_prediction_summary.py b/airflow/contrib/utils/mlengine_prediction_summary.py
deleted file mode 100644
index ea390525a359d..0000000000000
--- a/airflow/contrib/utils/mlengine_prediction_summary.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.google.cloud.utils.mlengine_prediction_summary`.
-"""
-
-import warnings
-
-from airflow.providers.google.cloud.utils.mlengine_prediction_summary import JsonCoder, MakeSummary # noqa
-
-warnings.warn(
- "This module is deprecated. "
- "Please use `airflow.providers.google.cloud.utils.mlengine_prediction_summary`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/contrib/utils/sendgrid.py b/airflow/contrib/utils/sendgrid.py
deleted file mode 100644
index 16408f92d3a5a..0000000000000
--- a/airflow/contrib/utils/sendgrid.py
+++ /dev/null
@@ -1,36 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-This module is deprecated.
-Please use `airflow.providers.sendgrid.utils.emailer`.
-"""
-
-import warnings
-
-from airflow.utils.module_loading import import_string
-
-
-def send_email(*args, **kwargs):
- """This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`."""
- warnings.warn(
- "This function is deprecated. Please use `airflow.providers.sendgrid.utils.emailer.send_email`.",
- DeprecationWarning,
- stacklevel=2,
- )
- return import_string('airflow.providers.sendgrid.utils.emailer.send_email')(*args, **kwargs)
diff --git a/airflow/contrib/utils/weekday.py b/airflow/contrib/utils/weekday.py
deleted file mode 100644
index 2f2448c8896da..0000000000000
--- a/airflow/contrib/utils/weekday.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.utils.weekday`."""
-import warnings
-
-from airflow.utils.weekday import WeekDay # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.utils.weekday`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index 38f5ffdbbf73a..1d8a082869efc 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -16,6 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""Processes DAGs."""
+from __future__ import annotations
+
+import collections
import enum
import importlib
import inspect
@@ -31,17 +34,21 @@
from datetime import datetime, timedelta
from importlib import import_module
from multiprocessing.connection import Connection as MultiprocessingConnection
-from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Union, cast
+from pathlib import Path
+from typing import Any, NamedTuple, cast
from setproctitle import setproctitle
from sqlalchemy.orm import Session
from tabulate import tabulate
import airflow.models
-from airflow.callbacks.callback_requests import CallbackRequest
+from airflow.callbacks.callback_requests import CallbackRequest, SlaCallbackRequest
from airflow.configuration import conf
from airflow.dag_processing.processor import DagFileProcessorProcess
-from airflow.models import DagModel, DbCallbackRequest, errors
+from airflow.models import errors
+from airflow.models.dag import DagModel
+from airflow.models.dagwarning import DagWarning
+from airflow.models.db_callback_request import DbCallbackRequest
from airflow.models.serialized_dag import SerializedDagModel
from airflow.stats import Stats
from airflow.utils import timezone
@@ -57,9 +64,6 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks
-if TYPE_CHECKING:
- import pathlib
-
class DagParsingStat(NamedTuple):
"""Information on processing progress"""
@@ -73,17 +77,17 @@ class DagFileStat(NamedTuple):
num_dags: int
import_errors: int
- last_finish_time: Optional[datetime]
- last_duration: Optional[timedelta]
+ last_finish_time: datetime | None
+ last_duration: timedelta | None
run_count: int
class DagParsingSignal(enum.Enum):
"""All signals sent to parser."""
- AGENT_RUN_ONCE = 'agent_run_once'
- TERMINATE_MANAGER = 'terminate_manager'
- END_MANAGER = 'end_manager'
+ AGENT_RUN_ONCE = "agent_run_once"
+ TERMINATE_MANAGER = "terminate_manager"
+ END_MANAGER = "end_manager"
class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin):
@@ -107,30 +111,29 @@ class DagFileProcessorAgent(LoggingMixin, MultiprocessingStartMethodMixin):
def __init__(
self,
- dag_directory: str,
+ dag_directory: os.PathLike,
max_runs: int,
processor_timeout: timedelta,
- dag_ids: Optional[List[str]],
+ dag_ids: list[str] | None,
pickle_dags: bool,
async_mode: bool,
):
super().__init__()
- self._file_path_queue: List[str] = []
- self._dag_directory: str = dag_directory
+ self._dag_directory: os.PathLike = dag_directory
self._max_runs = max_runs
self._processor_timeout = processor_timeout
self._dag_ids = dag_ids
self._pickle_dags = pickle_dags
self._async_mode = async_mode
# Map from file path to the processor
- self._processors: Dict[str, DagFileProcessorProcess] = {}
+ self._processors: dict[str, DagFileProcessorProcess] = {}
# Pipe for communicating signals
- self._process: Optional[multiprocessing.process.BaseProcess] = None
+ self._process: multiprocessing.process.BaseProcess | None = None
self._done: bool = False
# Initialized as true so we do not deactivate w/o any actual DAG parsing.
self._all_files_processed = True
- self._parent_signal_conn: Optional[MultiprocessingConnection] = None
+ self._parent_signal_conn: MultiprocessingConnection | None = None
self._last_parsing_stat_received_at: float = time.monotonic()
@@ -205,11 +208,11 @@ def wait_until_finished(self) -> None:
@staticmethod
def _run_processor_manager(
- dag_directory: str,
+ dag_directory: os.PathLike,
max_runs: int,
processor_timeout: timedelta,
signal_conn: MultiprocessingConnection,
- dag_ids: Optional[List[str]],
+ dag_ids: list[str] | None,
pickle_dags: bool,
async_mode: bool,
) -> None:
@@ -217,15 +220,15 @@ def _run_processor_manager(
# Make this process start as a new process group - that makes it easy
# to kill all sub-process of this at the OS-level, rather than having
# to iterate the child processes
- os.setpgid(0, 0)
+ set_new_process_group()
setproctitle("airflow scheduler -- DagFileProcessorManager")
# Reload configurations and settings to avoid collision with parent process.
# Because this process may need custom configurations that cannot be shared,
# e.g. RotatingFileHandler. And it can cause connection corruption if we
# do not recreate the SQLA connection pool.
- os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER'] = 'True'
- os.environ['AIRFLOW__LOGGING__COLORED_CONSOLE_LOG'] = 'False'
+ os.environ["CONFIG_PROCESSOR_MANAGER_LOGGER"] = "True"
+ os.environ["AIRFLOW__LOGGING__COLORED_CONSOLE_LOG"] = "False"
# Replicating the behavior of how logging module was loaded
# in logging_config.py
@@ -238,10 +241,10 @@ def _run_processor_manager(
# The issue that describes the problem and possible remediation is
# at https://github.com/apache/airflow/issues/19934
- importlib.reload(import_module(airflow.settings.LOGGING_CLASS_PATH.rsplit('.', 1)[0])) # type: ignore
+ importlib.reload(import_module(airflow.settings.LOGGING_CLASS_PATH.rsplit(".", 1)[0])) # type: ignore
importlib.reload(airflow.settings)
airflow.settings.initialize()
- del os.environ['CONFIG_PROCESSOR_MANAGER_LOGGER']
+ del os.environ["CONFIG_PROCESSOR_MANAGER_LOGGER"]
processor_manager = DagFileProcessorManager(
dag_directory=dag_directory,
max_runs=max_runs,
@@ -251,7 +254,6 @@ def _run_processor_manager(
signal_conn=signal_conn,
async_mode=async_mode,
)
-
processor_manager.start()
def heartbeat(self) -> None:
@@ -295,7 +297,7 @@ def _heartbeat_manager(self):
parsing_stat_age = time.monotonic() - self._last_parsing_stat_received_at
if parsing_stat_age > self._processor_timeout.total_seconds():
- Stats.incr('dag_processing.manager_stalls')
+ Stats.incr("dag_processing.manager_stalls")
self.log.error(
"DagFileProcessorManager (PID=%d) last sent a heartbeat %.2f seconds ago! Restarting it",
self._process.pid,
@@ -338,7 +340,7 @@ def end(self):
:return:
"""
if not self._process:
- self.log.warning('Ending without manager process.')
+ self.log.warning("Ending without manager process.")
return
# Give the Manager some time to cleanly shut down, but not too long, as
# it's better to finish sooner than wait for (non-critical) work to
@@ -369,25 +371,25 @@ class DagFileProcessorManager(LoggingMixin):
def __init__(
self,
- dag_directory: Union[str, "pathlib.Path"],
+ dag_directory: os.PathLike,
max_runs: int,
processor_timeout: timedelta,
- dag_ids: Optional[List[str]],
+ dag_ids: list[str] | None,
pickle_dags: bool,
- signal_conn: Optional[MultiprocessingConnection] = None,
+ signal_conn: MultiprocessingConnection | None = None,
async_mode: bool = True,
):
super().__init__()
- self._file_paths: List[str] = []
- self._file_path_queue: List[str] = []
- self._dag_directory = dag_directory
+ self._file_paths: list[str] = []
+ self._file_path_queue: collections.deque[str] = collections.deque()
self._max_runs = max_runs
# signal_conn is None for dag_processor_standalone mode.
self._direct_scheduler_conn = signal_conn
self._pickle_dags = pickle_dags
self._dag_ids = dag_ids
self._async_mode = async_mode
- self._parsing_start_time: Optional[int] = None
+ self._parsing_start_time: int | None = None
+ self._dag_directory = dag_directory
# Set the signal conn in to non-blocking mode, so that attempting to
# send when the buffer is full errors, rather than hangs for-ever
@@ -398,9 +400,10 @@ def __init__(
if self._async_mode and self._direct_scheduler_conn is not None:
os.set_blocking(self._direct_scheduler_conn.fileno(), False)
- self._parallelism = conf.getint('scheduler', 'parsing_processes')
+ self.standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
+ self._parallelism = conf.getint("scheduler", "parsing_processes")
if (
- conf.get_mandatory_value('database', 'sql_alchemy_conn').startswith('sqlite')
+ conf.get_mandatory_value("database", "sql_alchemy_conn").startswith("sqlite")
and self._parallelism > 1
):
self.log.warning(
@@ -411,18 +414,18 @@ def __init__(
self._parallelism = 1
# Parse and schedule each file no faster than this interval.
- self._file_process_interval = conf.getint('scheduler', 'min_file_process_interval')
+ self._file_process_interval = conf.getint("scheduler", "min_file_process_interval")
# How often to print out DAG file processing stats to the log. Default to
# 30 seconds.
- self.print_stats_interval = conf.getint('scheduler', 'print_stats_interval')
+ self.print_stats_interval = conf.getint("scheduler", "print_stats_interval")
# Map from file path to the processor
- self._processors: Dict[str, DagFileProcessorProcess] = {}
+ self._processors: dict[str, DagFileProcessorProcess] = {}
self._num_run = 0
# Map from file path to stats about the file
- self._file_stats: Dict[str, DagFileStat] = {}
+ self._file_stats: dict[str, DagFileStat] = {}
# Last time that the DAG dir was traversed to look for files
self.last_dag_dir_refresh_time = timezone.make_aware(datetime.fromtimestamp(0))
@@ -431,18 +434,18 @@ def __init__(
# Last time we cleaned up DAGs which are no longer in files
self.last_deactivate_stale_dags_time = timezone.make_aware(datetime.fromtimestamp(0))
# How often to check for DAGs which are no longer in files
- self.deactivate_stale_dags_interval = conf.getint('scheduler', 'deactivate_stale_dags_interval')
+ self.parsing_cleanup_interval = conf.getint("scheduler", "parsing_cleanup_interval")
# How long to wait before timing out a process to parse a DAG file
self._processor_timeout = processor_timeout
# How often to scan the DAGs directory for new files. Default to 5 minutes.
- self.dag_dir_list_interval = conf.getint('scheduler', 'dag_dir_list_interval')
+ self.dag_dir_list_interval = conf.getint("scheduler", "dag_dir_list_interval")
# Mapping file name and callbacks requests
- self._callback_to_execute: Dict[str, List[CallbackRequest]] = defaultdict(list)
+ self._callback_to_execute: dict[str, list[CallbackRequest]] = defaultdict(list)
- self._log = logging.getLogger('airflow.processor_manager')
+ self._log = logging.getLogger("airflow.processor_manager")
- self.waitables: Dict[Any, Union[MultiprocessingConnection, DagFileProcessorProcess]] = (
+ self.waitables: dict[Any, MultiprocessingConnection | DagFileProcessorProcess] = (
{
self._direct_scheduler_conn: self._direct_scheduler_conn,
}
@@ -460,7 +463,7 @@ def register_exit_signals(self):
def _exit_gracefully(self, signum, frame):
"""Helper method to clean up DAG file processors to avoid leaving orphan processes."""
self.log.info("Exiting gracefully upon receiving signal %s", signum)
- self.log.debug("Current Stacktrace is: %s", '\n'.join(map(str, inspect.stack())))
+ self.log.debug("Current Stacktrace is: %s", "\n".join(map(str, inspect.stack())))
self.terminate()
self.end()
self.log.debug("Finished terminating DAG processors.")
@@ -494,16 +497,18 @@ def _deactivate_stale_dags(self, session=None):
"""
now = timezone.utcnow()
elapsed_time_since_refresh = (now - self.last_deactivate_stale_dags_time).total_seconds()
- if elapsed_time_since_refresh > self.deactivate_stale_dags_interval:
+ if elapsed_time_since_refresh > self.parsing_cleanup_interval:
last_parsed = {
fp: self.get_last_finish_time(fp) for fp in self.file_paths if self.get_last_finish_time(fp)
}
to_deactivate = set()
- dags_parsed = (
- session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time)
- .filter(DagModel.is_active)
- .all()
+ query = session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).filter(
+ DagModel.is_active
)
+ if self.standalone_dag_processor:
+ query = query.filter(DagModel.processor_subdir == self.get_dag_directory())
+ dags_parsed = query.all()
+
for dag in dags_parsed:
# The largest valid difference between a DagFileStat's last_finished_time and a DAG's
# last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is
@@ -541,7 +546,7 @@ def _run_parsing_loop(self):
self._refresh_dag_dir()
self.prepare_file_path_queue()
max_callbacks_per_loop = conf.getint("scheduler", "max_callbacks_per_loop")
- standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
+
if self._async_mode:
# If we're in async mode, we can start up straight away. If we're
# in sync mode we need to be told to start a "loop"
@@ -592,10 +597,11 @@ def _run_parsing_loop(self):
self.waitables.pop(sentinel)
self._processors.pop(processor.file_path)
- if standalone_dag_processor:
+ if self.standalone_dag_processor:
self._fetch_callbacks(max_callbacks_per_loop)
self._deactivate_stale_dags()
- self._refresh_dag_dir()
+ DagWarning.purge_inactive_dag_warnings()
+ refreshed_dag_dir = self._refresh_dag_dir()
self._kill_timed_out_processors()
@@ -604,6 +610,8 @@ def _run_parsing_loop(self):
if not self._file_path_queue:
self.emit_metrics()
self.prepare_file_path_queue()
+ elif refreshed_dag_dir:
+ self.add_new_file_path_to_queue()
self.start_new_processes()
@@ -661,11 +669,12 @@ def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION):
"""Fetches callbacks from database and add them to the internal queue for execution."""
self.log.debug("Fetching callbacks from the database.")
with prohibit_commit(session) as guard:
- query = (
- session.query(DbCallbackRequest)
- .order_by(DbCallbackRequest.priority_weight.asc())
- .limit(max_callbacks)
- )
+ query = session.query(DbCallbackRequest)
+ if self.standalone_dag_processor:
+ query = query.filter(
+ DbCallbackRequest.processor_subdir == self.get_dag_directory(),
+ )
+ query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks)
callbacks = with_row_locks(
query, of=DbCallbackRequest, session=session, **skip_locked(session=session)
).all()
@@ -678,16 +687,35 @@ def _fetch_callbacks(self, max_callbacks: int, session: Session = NEW_SESSION):
guard.commit()
def _add_callback_to_queue(self, request: CallbackRequest):
- self._callback_to_execute[request.full_filepath].append(request)
- # Callback has a higher priority over DAG Run scheduling
- if request.full_filepath in self._file_path_queue:
- # Remove file paths matching request.full_filepath from self._file_path_queue
- # Since we are already going to use that filepath to run callback,
- # there is no need to have same file path again in the queue
- self._file_path_queue = [
- file_path for file_path in self._file_path_queue if file_path != request.full_filepath
- ]
- self._file_path_queue.insert(0, request.full_filepath)
+
+ # Requests are sent by DAG processors. SLAs exist per DAG, but can be generated once per SLA-enabled
+ # task in the DAG. If treated like other callbacks, SLAs can cause a feedback loop: an SLA arrives,
+ # goes to the front of the queue, gets processed, and triggers more SLAs from the same DAG. Those go
+ # to the front of the queue too, so we never get round to picking files off the back of the queue.
+ if isinstance(request, SlaCallbackRequest):
+ if request in self._callback_to_execute[request.full_filepath]:
+ self.log.debug("Skipping already queued SlaCallbackRequest")
+ return
+
+ # not already queued, queue the file _at the back_, and add the request to the file's callbacks
+ self.log.debug("Queuing SlaCallbackRequest for %s", request.dag_id)
+ self._callback_to_execute[request.full_filepath].append(request)
+ if request.full_filepath not in self._file_path_queue:
+ self._file_path_queue.append(request.full_filepath)
+
+ # Other callbacks have a higher priority over DAG Run scheduling, so those callbacks gazump, even if
+ # already in the queue
+ else:
+ self.log.debug("Queuing %s CallbackRequest: %s", type(request).__name__, request)
+ self._callback_to_execute[request.full_filepath].append(request)
+ if request.full_filepath in self._file_path_queue:
+ # Remove file paths matching request.full_filepath from self._file_path_queue
+ # Since we are already going to use that filepath to run callback,
+ # there is no need to have same file path again in the queue
+ self._file_path_queue = collections.deque(
+ file_path for file_path in self._file_path_queue if file_path != request.full_filepath
+ )
+ self._file_path_queue.appendleft(request.full_filepath)
def _refresh_dag_dir(self):
"""Refresh file paths from dag dir if we haven't done it for too long."""
@@ -713,24 +741,33 @@ def _refresh_dag_dir(self):
dag_filelocs = []
for fileloc in self._file_paths:
if not fileloc.endswith(".py") and zipfile.is_zipfile(fileloc):
- with zipfile.ZipFile(fileloc) as z:
- dag_filelocs.extend(
- [
- os.path.join(fileloc, info.filename)
- for info in z.infolist()
- if might_contain_dag(info.filename, True, z)
- ]
- )
+ try:
+ with zipfile.ZipFile(fileloc) as z:
+ dag_filelocs.extend(
+ [
+ os.path.join(fileloc, info.filename)
+ for info in z.infolist()
+ if might_contain_dag(info.filename, True, z)
+ ]
+ )
+ except zipfile.BadZipFile as err:
+ self.log.error("There was an error accessing %s, %s", fileloc, err)
else:
dag_filelocs.append(fileloc)
- SerializedDagModel.remove_deleted_dags(dag_filelocs)
+ SerializedDagModel.remove_deleted_dags(
+ alive_dag_filelocs=dag_filelocs,
+ processor_subdir=self.get_dag_directory(),
+ )
DagModel.deactivate_deleted_dags(self._file_paths)
from airflow.models.dagcode import DagCode
DagCode.remove_deleted_code(dag_filelocs)
+ return True
+ return False
+
def _print_stat(self):
"""Occasionally print out stats about how fast the files are getting processed"""
if 0 < self.print_stats_interval < time.monotonic() - self.last_stat_print_time:
@@ -748,7 +785,7 @@ def clear_nonexistent_import_errors(self, session):
query = session.query(errors.ImportError)
if self._file_paths:
query = query.filter(~errors.ImportError.filename.in_(self._file_paths))
- query.delete(synchronize_session='fetch')
+ query.delete(synchronize_session="fetch")
session.commit()
def _log_file_processing_stats(self, known_file_paths):
@@ -776,7 +813,7 @@ def _log_file_processing_stats(self, known_file_paths):
num_dags = self.get_last_dag_count(file_path)
num_errors = self.get_last_error_count(file_path)
file_name = os.path.basename(file_path)
- file_name = os.path.splitext(file_name)[0].replace(os.sep, '.')
+ file_name = os.path.splitext(file_name)[0].replace(os.sep, ".")
processor_pid = self.get_pid(file_path)
processor_start_time = self.get_start_time(file_path)
@@ -784,7 +821,7 @@ def _log_file_processing_stats(self, known_file_paths):
last_run = self.get_last_finish_time(file_path)
if last_run:
seconds_ago = (now - last_run).total_seconds()
- Stats.gauge(f'dag_processing.last_run.seconds_ago.{file_name}', seconds_ago)
+ Stats.gauge(f"dag_processing.last_run.seconds_ago.{file_name}", seconds_ago)
rows.append((file_path, processor_pid, runtime, num_dags, num_errors, last_runtime, last_run))
@@ -816,84 +853,85 @@ def _log_file_processing_stats(self, known_file_paths):
self.log.info(log_str)
- def get_pid(self, file_path):
+ def get_pid(self, file_path) -> int | None:
"""
:param file_path: the path to the file that's being processed
:return: the PID of the process processing the given file or None if
the specified file is not being processed
- :rtype: int
"""
if file_path in self._processors:
return self._processors[file_path].pid
return None
- def get_all_pids(self):
+ def get_all_pids(self) -> list[int]:
"""
+ Get all pids.
+
:return: a list of the PIDs for the processors that are running
- :rtype: List[int]
"""
return [x.pid for x in self._processors.values()]
- def get_last_runtime(self, file_path):
+ def get_last_runtime(self, file_path) -> float | None:
"""
:param file_path: the path to the file that was processed
:return: the runtime (in seconds) of the process of the last run, or
None if the file was never processed.
- :rtype: float
"""
stat = self._file_stats.get(file_path)
return stat.last_duration.total_seconds() if stat and stat.last_duration else None
- def get_last_dag_count(self, file_path):
+ def get_last_dag_count(self, file_path) -> int | None:
"""
:param file_path: the path to the file that was processed
:return: the number of dags loaded from that file, or None if the file
was never processed.
- :rtype: int
"""
stat = self._file_stats.get(file_path)
return stat.num_dags if stat else None
- def get_last_error_count(self, file_path):
+ def get_last_error_count(self, file_path) -> int | None:
"""
:param file_path: the path to the file that was processed
:return: the number of import errors from processing, or None if the file
was never processed.
- :rtype: int
"""
stat = self._file_stats.get(file_path)
return stat.import_errors if stat else None
- def get_last_finish_time(self, file_path):
+ def get_last_finish_time(self, file_path) -> datetime | None:
"""
:param file_path: the path to the file that was processed
:return: the finish time of the process of the last run, or None if the
file was never processed.
- :rtype: datetime
"""
stat = self._file_stats.get(file_path)
return stat.last_finish_time if stat else None
- def get_start_time(self, file_path):
+ def get_start_time(self, file_path) -> datetime | None:
"""
:param file_path: the path to the file that's being processed
:return: the start time of the process that's processing the
specified file or None if the file is not currently being processed
- :rtype: datetime
"""
if file_path in self._processors:
return self._processors[file_path].start_time
return None
- def get_run_count(self, file_path):
+ def get_run_count(self, file_path) -> int:
"""
:param file_path: the path to the file that's being processed
:return: the number of times the given file has been parsed
- :rtype: int
"""
stat = self._file_stats.get(file_path)
return stat.run_count if stat else 0
+ def get_dag_directory(self) -> str:
+ """Returns the dag_directory as a string."""
+ if isinstance(self._dag_directory, Path):
+ return str(self._dag_directory.resolve())
+ else:
+ return str(self._dag_directory)
+
def set_file_paths(self, new_file_paths):
"""
Update this with a new set of paths to DAG definition files.
@@ -902,7 +940,7 @@ def set_file_paths(self, new_file_paths):
:return: None
"""
self._file_paths = new_file_paths
- self._file_path_queue = [x for x in self._file_path_queue if x in new_file_paths]
+ self._file_path_queue = collections.deque(x for x in self._file_path_queue if x in new_file_paths)
# Stop processors that are working on deleted files
filtered_processors = {}
for file_path, processor in self._processors.items():
@@ -910,9 +948,15 @@ def set_file_paths(self, new_file_paths):
filtered_processors[file_path] = processor
else:
self.log.warning("Stopping processor for %s", file_path)
- Stats.decr('dag_processing.processes')
+ Stats.decr("dag_processing.processes")
processor.terminate()
self._file_stats.pop(file_path)
+
+ to_remove = set(self._file_stats.keys()) - set(self._file_paths)
+ for key in to_remove:
+ # Remove the stats for any dag files that don't exist anymore
+ del self._file_stats[key]
+
self._processors = filtered_processors
def wait_until_finished(self):
@@ -923,7 +967,7 @@ def wait_until_finished(self):
def _collect_results_from_processor(self, processor) -> None:
self.log.debug("Processor for %s finished", processor.file_path)
- Stats.decr('dag_processing.processes')
+ Stats.decr("dag_processing.processes")
last_finish_time = timezone.utcnow()
if processor.result is not None:
@@ -945,8 +989,8 @@ def _collect_results_from_processor(self, processor) -> None:
)
self._file_stats[processor.file_path] = stat
- file_name = os.path.splitext(os.path.basename(processor.file_path))[0].replace(os.sep, '.')
- Stats.timing(f'dag_processing.last_duration.{file_name}', last_duration)
+ file_name = os.path.splitext(os.path.basename(processor.file_path))[0].replace(os.sep, ".")
+ Stats.timing(f"dag_processing.last_duration.{file_name}", last_duration)
def collect_results(self) -> None:
"""Collect the result from any finished DAG processors"""
@@ -967,33 +1011,51 @@ def collect_results(self) -> None:
self.log.debug("%s file paths queued for processing", len(self._file_path_queue))
@staticmethod
- def _create_process(file_path, pickle_dags, dag_ids, callback_requests):
+ def _create_process(file_path, pickle_dags, dag_ids, dag_directory, callback_requests):
"""Creates DagFileProcessorProcess instance."""
return DagFileProcessorProcess(
- file_path=file_path, pickle_dags=pickle_dags, dag_ids=dag_ids, callback_requests=callback_requests
+ file_path=file_path,
+ pickle_dags=pickle_dags,
+ dag_ids=dag_ids,
+ dag_directory=dag_directory,
+ callback_requests=callback_requests,
)
def start_new_processes(self):
"""Start more processors if we have enough slots and files to process"""
while self._parallelism - len(self._processors) > 0 and self._file_path_queue:
- file_path = self._file_path_queue.pop(0)
+ file_path = self._file_path_queue.popleft()
# Stop creating duplicate processor i.e. processor with the same filepath
if file_path in self._processors.keys():
continue
callback_to_execute_for_file = self._callback_to_execute[file_path]
processor = self._create_process(
- file_path, self._pickle_dags, self._dag_ids, callback_to_execute_for_file
+ file_path,
+ self._pickle_dags,
+ self._dag_ids,
+ self.get_dag_directory(),
+ callback_to_execute_for_file,
)
del self._callback_to_execute[file_path]
- Stats.incr('dag_processing.processes')
+ Stats.incr("dag_processing.processes")
processor.start()
self.log.debug("Started a process (PID: %s) to generate tasks for %s", processor.pid, file_path)
self._processors[file_path] = processor
self.waitables[processor.waitable_handle] = processor
+ def add_new_file_path_to_queue(self):
+ for file_path in self.file_paths:
+ if file_path not in self._file_stats:
+ # We found a new file after refreshing the directory; add it to the front of the parsing queue
+ self.log.info("Adding new file %s to parsing queue", file_path)
+ self._file_stats[file_path] = DagFileStat(
+ num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0
+ )
+ self._file_path_queue.appendleft(file_path)
+
def prepare_file_path_queue(self):
"""Generate more file paths to process. Results are saved in _file_path_queue."""
self._parsing_start_time = time.perf_counter()
@@ -1010,6 +1072,7 @@ def prepare_file_path_queue(self):
is_mtime_mode = list_mode == "modified_time"
file_paths_recently_processed = []
+ file_paths_to_stop_watching = set()
for file_path in self._file_paths:
if is_mtime_mode:
@@ -1017,6 +1080,8 @@ def prepare_file_path_queue(self):
files_with_mtime[file_path] = os.path.getmtime(file_path)
except FileNotFoundError:
self.log.warning("Skipping processing of missing file: %s", file_path)
+ self._file_stats.pop(file_path, None)
+ file_paths_to_stop_watching.add(file_path)
continue
file_modified_time = timezone.make_aware(datetime.fromtimestamp(files_with_mtime[file_path]))
else:
@@ -1045,12 +1110,18 @@ def prepare_file_path_queue(self):
# set of files. Since we set the seed, the sort order will remain same per host
random.Random(get_hostname()).shuffle(file_paths)
+ if file_paths_to_stop_watching:
+ self.set_file_paths(
+ [path for path in self._file_paths if path not in file_paths_to_stop_watching]
+ )
+
files_paths_at_run_limit = [
file_path for file_path, stat in self._file_stats.items() if stat.run_count == self._max_runs
]
file_paths_to_exclude = set(file_paths_in_progress).union(
- file_paths_recently_processed, files_paths_at_run_limit
+ file_paths_recently_processed,
+ files_paths_at_run_limit,
)
# Do not convert the following list to set as set does not preserve the order
@@ -1068,12 +1139,11 @@ def prepare_file_path_queue(self):
self.log.debug("Queuing the following files for processing:\n\t%s", "\n\t".join(files_paths_to_queue))
+ default = DagFileStat(
+ num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0
+ )
for file_path in files_paths_to_queue:
- if file_path not in self._file_stats:
- self._file_stats[file_path] = DagFileStat(
- num_dags=0, import_errors=0, last_finish_time=None, last_duration=None, run_count=0
- )
-
+ self._file_stats.setdefault(file_path, default)
self._file_path_queue.extend(files_paths_to_queue)
def _kill_timed_out_processors(self):
@@ -1089,16 +1159,25 @@ def _kill_timed_out_processors(self):
processor.pid,
processor.start_time.isoformat(),
)
- Stats.decr('dag_processing.processes')
- Stats.incr('dag_processing.processor_timeouts')
- # TODO: Remove after Airflow 2.0
- Stats.incr('dag_file_processor_timeouts')
+ Stats.decr("dag_processing.processes")
+ Stats.incr("dag_processing.processor_timeouts")
+ # Deprecated; may be removed in a future Airflow release.
+ Stats.incr("dag_file_processor_timeouts")
processor.kill()
# Clean up processor references
self.waitables.pop(processor.waitable_handle)
processors_to_remove.append(file_path)
+ stat = DagFileStat(
+ num_dags=0,
+ import_errors=1,
+ last_finish_time=now,
+ last_duration=duration,
+ run_count=self.get_run_count(file_path) + 1,
+ )
+ self._file_stats[processor.file_path] = stat
+
# Clean up `self._processors` after iterating over it
for proc in processors_to_remove:
self._processors.pop(proc)
@@ -1120,7 +1199,7 @@ def terminate(self):
:return: None
"""
for processor in self._processors.values():
- Stats.decr('dag_processing.processes')
+ Stats.decr("dag_processing.processes")
processor.terminate()
def end(self):
@@ -1140,10 +1219,10 @@ def emit_metrics(self):
all files have been parsed.
"""
parse_time = time.perf_counter() - self._parsing_start_time
- Stats.gauge('dag_processing.total_parse_time', parse_time)
- Stats.gauge('dagbag_size', sum(stat.num_dags for stat in self._file_stats.values()))
+ Stats.gauge("dag_processing.total_parse_time", parse_time)
+ Stats.gauge("dagbag_size", sum(stat.num_dags for stat in self._file_stats.values()))
Stats.gauge(
- 'dag_processing.import_errors', sum(stat.import_errors for stat in self._file_stats.values())
+ "dag_processing.import_errors", sum(stat.import_errors for stat in self._file_stats.values())
)
@property
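Review note on the queue changes in `manager.py`: the file path queue moves from a list to `collections.deque`, so taking the next file uses `popleft()` and newly discovered files are pushed to the front with `appendleft()`. The following is a minimal standalone sketch of that pattern, not the Airflow implementation. The helper names `refresh_queue` and `add_new_files` and the `known` set are hypothetical stand-ins for `set_file_paths`, `add_new_file_path_to_queue` and `_file_stats`.

```python
from collections import deque


def refresh_queue(queue, file_paths):
    """Keep only queued paths that are still present, preserving their order."""
    keep = set(file_paths)
    return deque(p for p in queue if p in keep)


def add_new_files(queue, file_paths, known):
    """Push paths that have not been seen before to the front of the queue."""
    for path in file_paths:
        if path not in known:
            known.add(path)
            queue.appendleft(path)


# Hypothetical data: b.py was deleted and d.py is new.
current_files = ["a.py", "c.py", "d.py"]
queue = refresh_queue(deque(["a.py", "b.py", "c.py"]), current_files)
known = {"a.py", "c.py"}
add_new_files(queue, current_files, known)

# popleft() on a deque is O(1), whereas list.pop(0) shifts every element.
first = queue.popleft()
```

In this sketch the new file is parsed before the files that were already waiting, which matches the intent stated in the comment inside `add_new_file_path_to_queue`.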
diff --git a/airflow/dag_processing/processor.py b/airflow/dag_processing/processor.py
index 34821ffb92010..02bb5eeebf399 100644
--- a/airflow/dag_processing/processor.py
+++ b/airflow/dag_processing/processor.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import datetime
import logging
@@ -25,13 +26,13 @@
from contextlib import redirect_stderr, redirect_stdout, suppress
from datetime import timedelta
from multiprocessing.connection import Connection as MultiprocessingConnection
-from typing import Iterator, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Iterator
from setproctitle import setproctitle
-from sqlalchemy import func, or_
+from sqlalchemy import exc, func, or_
from sqlalchemy.orm.session import Session
-from airflow import models, settings
+from airflow import settings
from airflow.callbacks.callback_requests import (
CallbackRequest,
DagCallbackRequest,
@@ -43,6 +44,9 @@
from airflow.models import SlaMiss, errors
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
+from airflow.models.dagrun import DagRun as DR
+from airflow.models.dagwarning import DagWarning, DagWarningType
+from airflow.models.taskinstance import TaskInstance as TI
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.email import get_email_address_list, send_email
@@ -51,8 +55,8 @@
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import State
-DR = models.DagRun
-TI = models.TaskInstance
+if TYPE_CHECKING:
+ from airflow.models.operator import Operator
class DagFileProcessorProcess(LoggingMixin, MultiprocessingStartMethodMixin):
@@ -71,28 +75,30 @@ def __init__(
self,
file_path: str,
pickle_dags: bool,
- dag_ids: Optional[List[str]],
- callback_requests: List[CallbackRequest],
+ dag_ids: list[str] | None,
+ dag_directory: str,
+ callback_requests: list[CallbackRequest],
):
super().__init__()
self._file_path = file_path
self._pickle_dags = pickle_dags
self._dag_ids = dag_ids
+ self._dag_directory = dag_directory
self._callback_requests = callback_requests
# The process that was launched to process the given file.
- self._process: Optional[multiprocessing.process.BaseProcess] = None
+ self._process: multiprocessing.process.BaseProcess | None = None
# The result of DagFileProcessor.process_file(file_path).
- self._result: Optional[Tuple[int, int]] = None
+ self._result: tuple[int, int] | None = None
# Whether the process is done running.
self._done = False
# When the process started.
- self._start_time: Optional[datetime.datetime] = None
+ self._start_time: datetime.datetime | None = None
# This ID is used to uniquely name the process / thread that's launched
# by this processor instance
self._instance_id = DagFileProcessorProcess.class_creation_counter
- self._parent_channel: Optional[MultiprocessingConnection] = None
+ self._parent_channel: MultiprocessingConnection | None = None
DagFileProcessorProcess.class_creation_counter += 1
@property
@@ -105,9 +111,10 @@ def _run_file_processor(
parent_channel: MultiprocessingConnection,
file_path: str,
pickle_dags: bool,
- dag_ids: Optional[List[str]],
+ dag_ids: list[str] | None,
thread_name: str,
- callback_requests: List[CallbackRequest],
+ dag_directory: str,
+ callback_requests: list[CallbackRequest],
) -> None:
"""
Process the given file.
@@ -122,42 +129,49 @@ def _run_file_processor(
:param thread_name: the name to use for the process that is launched
:param callback_requests: failure callback to execute
:return: the process that was launched
- :rtype: multiprocessing.Process
"""
# This helper runs in the newly created process
log: logging.Logger = logging.getLogger("airflow.processor")
# Since we share all open FDs from the parent, we need to close the parent side of the pipe here in
# the child, else it won't get closed properly until we exit.
- log.info("Closing parent pipe")
-
parent_channel.close()
del parent_channel
set_context(log, file_path)
setproctitle(f"airflow scheduler - DagFileProcessor {file_path}")
+ def _handle_dag_file_processing():
+ # Re-configure the ORM engine as there are issues with multiple processes
+ settings.configure_orm()
+
+ # Change the thread name to differentiate log lines. This is
+ # really a separate process, but changing the name of the
+ # process doesn't work, so changing the thread name instead.
+ threading.current_thread().name = thread_name
+
+ log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
+ dag_file_processor = DagFileProcessor(dag_ids=dag_ids, dag_directory=dag_directory, log=log)
+ result: tuple[int, int] = dag_file_processor.process_file(
+ file_path=file_path,
+ pickle_dags=pickle_dags,
+ callback_requests=callback_requests,
+ )
+ result_channel.send(result)
+
try:
- # redirect stdout/stderr to log
- with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr(
- StreamLogWriter(log, logging.WARN)
- ), Stats.timer() as timer:
- # Re-configure the ORM engine as there are issues with multiple processes
- settings.configure_orm()
-
- # Change the thread name to differentiate log lines. This is
- # really a separate process, but changing the name of the
- # process doesn't work, so changing the thread name instead.
- threading.current_thread().name = thread_name
-
- log.info("Started process (PID=%s) to work on %s", os.getpid(), file_path)
- dag_file_processor = DagFileProcessor(dag_ids=dag_ids, log=log)
- result: Tuple[int, int] = dag_file_processor.process_file(
- file_path=file_path,
- pickle_dags=pickle_dags,
- callback_requests=callback_requests,
- )
- result_channel.send(result)
+ DAG_PROCESSOR_LOG_TARGET = conf.get_mandatory_value("logging", "DAG_PROCESSOR_LOG_TARGET")
+ if DAG_PROCESSOR_LOG_TARGET == "stdout":
+ with Stats.timer() as timer:
+ _handle_dag_file_processing()
+ else:
+ # The following line ensures that stdout goes to the same destination as the logs. If stdout
+ # gets sent to logs and logs are sent to stdout, this leads to an infinite loop. This
+ # necessitates this conditional based on the value of DAG_PROCESSOR_LOG_TARGET.
+ with redirect_stdout(StreamLogWriter(log, logging.INFO)), redirect_stderr(
+ StreamLogWriter(log, logging.WARN)
+ ), Stats.timer() as timer:
+ _handle_dag_file_processing()
log.info("Processing %s took %.3f seconds", file_path, timer.duration)
except Exception:
# Log exceptions through the logging framework.
@@ -185,6 +199,7 @@ def start(self) -> None:
self._pickle_dags,
self._dag_ids,
f"DagFileProcessor{self._instance_id}",
+ self._dag_directory,
self._callback_requests,
),
name=f"DagFileProcessor{self._instance_id}-Process",
@@ -243,21 +258,17 @@ def _kill_process(self) -> None:
@property
def pid(self) -> int:
- """
- :return: the PID of the process launched to process the given file
- :rtype: int
- """
+ """PID of the process launched to process the given file."""
if self._process is None or self._process.pid is None:
raise AirflowException("Tried to get PID before starting!")
return self._process.pid
@property
- def exit_code(self) -> Optional[int]:
+ def exit_code(self) -> int | None:
"""
After the process is finished, this can be called to get the return code
:return: the exit code of the process
- :rtype: int
"""
if self._process is None:
raise AirflowException("Tried to get exit code before starting!")
@@ -271,7 +282,6 @@ def done(self) -> bool:
Check if the process launched to process this file is done.
:return: whether the process is finished running
- :rtype: bool
"""
if self._process is None or self._parent_channel is None:
raise AirflowException("Tried to see if it's done before starting!")
@@ -309,21 +319,15 @@ def done(self) -> bool:
return False
@property
- def result(self) -> Optional[Tuple[int, int]]:
- """
- :return: result of running DagFileProcessor.process_file()
- :rtype: tuple[int, int] or None
- """
+ def result(self) -> tuple[int, int] | None:
+ """Result of running ``DagFileProcessor.process_file()``."""
if not self.done:
raise AirflowException("Tried to get the result before it's done!")
return self._result
@property
def start_time(self) -> datetime.datetime:
- """
- :return: when this started to process the file
- :rtype: datetime
- """
+ """Time when this started to process the file."""
if self._start_time is None:
raise AirflowException("Tried to get start time before it started!")
return self._start_time
@@ -351,12 +355,14 @@ class DagFileProcessor(LoggingMixin):
:param log: Logger to save the processing process
"""
- UNIT_TEST_MODE: bool = conf.getboolean('core', 'UNIT_TEST_MODE')
+ UNIT_TEST_MODE: bool = conf.getboolean("core", "UNIT_TEST_MODE")
- def __init__(self, dag_ids: Optional[List[str]], log: logging.Logger):
+ def __init__(self, dag_ids: list[str] | None, dag_directory: str, log: logging.Logger):
super().__init__()
self.dag_ids = dag_ids
self._log = log
+ self._dag_directory = dag_directory
+ self.dag_warnings: set[tuple[str, str]] = set()
@provide_session
def manage_slas(self, dag: DAG, session: Session = None) -> None:
@@ -373,14 +379,13 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
return
qry = (
- session.query(TI.task_id, func.max(DR.execution_date).label('max_ti'))
+ session.query(TI.task_id, func.max(DR.execution_date).label("max_ti"))
.join(TI.dag_run)
- .with_hint(TI, 'USE INDEX (PRIMARY)', dialect_name='mysql')
.filter(TI.dag_id == dag.dag_id)
.filter(or_(TI.state == State.SUCCESS, TI.state == State.SKIPPED))
.filter(TI.task_id.in_(dag.task_ids))
.group_by(TI.task_id)
- .subquery('sq')
+ .subquery("sq")
)
# get recorded SlaMiss
recorded_slas_query = set(
@@ -414,42 +419,40 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
sla_misses = []
next_info = dag.next_dagrun_info(dag.get_run_data_interval(ti.dag_run), restricted=False)
- if next_info is None:
- self.log.info("Skipping SLA check for %s because task does not have scheduled date", ti)
- else:
- while next_info.logical_date < ts:
- next_info = dag.next_dagrun_info(next_info.data_interval, restricted=False)
-
- if next_info is None:
- break
- if (ti.dag_id, ti.task_id, next_info.logical_date) in recorded_slas_query:
- break
- if next_info.logical_date + task.sla < ts:
-
- sla_miss = SlaMiss(
- task_id=ti.task_id,
- dag_id=ti.dag_id,
- execution_date=next_info.logical_date,
- timestamp=ts,
- )
- sla_misses.append(sla_miss)
+ while next_info and next_info.logical_date < ts:
+ next_info = dag.next_dagrun_info(next_info.data_interval, restricted=False)
+
+ if next_info is None:
+ break
+ if (ti.dag_id, ti.task_id, next_info.logical_date) in recorded_slas_query:
+ continue
+ if next_info.logical_date + task.sla < ts:
+
+ sla_miss = SlaMiss(
+ task_id=ti.task_id,
+ dag_id=ti.dag_id,
+ execution_date=next_info.logical_date,
+ timestamp=ts,
+ )
+ sla_misses.append(sla_miss)
+ Stats.incr("sla_missed")
if sla_misses:
session.add_all(sla_misses)
session.commit()
- slas: List[SlaMiss] = (
+ slas: list[SlaMiss] = (
session.query(SlaMiss)
.filter(SlaMiss.notification_sent == False, SlaMiss.dag_id == dag.dag_id) # noqa
.all()
)
if slas:
- sla_dates: List[datetime.datetime] = [sla.execution_date for sla in slas]
- fetched_tis: List[TI] = (
+ sla_dates: list[datetime.datetime] = [sla.execution_date for sla in slas]
+ fetched_tis: list[TI] = (
session.query(TI)
.filter(TI.state != State.SUCCESS, TI.execution_date.in_(sla_dates), TI.dag_id == dag.dag_id)
.all()
)
- blocking_tis: List[TI] = []
+ blocking_tis: list[TI] = []
for ti in fetched_tis:
if ti.task_id in dag.task_ids:
ti.task = dag.get_task(ti.task_id)
@@ -458,9 +461,9 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
session.delete(ti)
session.commit()
- task_list = "\n".join(sla.task_id + ' on ' + sla.execution_date.isoformat() for sla in slas)
+ task_list = "\n".join(sla.task_id + " on " + sla.execution_date.isoformat() for sla in slas)
blocking_task_list = "\n".join(
- ti.task_id + ' on ' + ti.execution_date.isoformat() for ti in blocking_tis
+ ti.task_id + " on " + ti.execution_date.isoformat() for ti in blocking_tis
)
# Track whether email or any alert notification sent
# We consider email or the alert callback as notifications
@@ -468,12 +471,12 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
notification_sent = False
if dag.sla_miss_callback:
# Execute the alert callback
- self.log.info('Calling SLA miss callback')
+ self.log.info("Calling SLA miss callback")
try:
dag.sla_miss_callback(dag, task_list, blocking_task_list, slas, blocking_tis)
notification_sent = True
except Exception:
- Stats.incr('sla_callback_notification_failure')
+ Stats.incr("sla_callback_notification_failure")
self.log.exception("Could not call sla_miss_callback for DAG %s", dag.dag_id)
email_content = f"""\
Here's a list of tasks that missed their SLAs:
@@ -495,7 +498,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
continue
tasks_missed_sla.append(task)
- emails: Set[str] = set()
+ emails: set[str] = set()
for task in tasks_missed_sla:
if task.email:
if isinstance(task.email, str):
@@ -508,7 +511,7 @@ def manage_slas(self, dag: DAG, session: Session = None) -> None:
email_sent = True
notification_sent = True
except Exception:
- Stats.incr('sla_email_notification_failure')
+ Stats.incr("sla_email_notification_failure")
self.log.exception("Could not send SLA Miss email notification for DAG %s", dag.dag_id)
# If we sent any notification, update the sla_miss table
if notification_sent:
@@ -545,7 +548,7 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None:
if filename in existing_import_error_files:
session.query(errors.ImportError).filter(errors.ImportError.filename == filename).update(
dict(filename=filename, timestamp=timezone.utcnow(), stacktrace=stacktrace),
- synchronize_session='fetch',
+ synchronize_session="fetch",
)
else:
session.add(
@@ -554,14 +557,63 @@ def update_import_errors(session: Session, dagbag: DagBag) -> None:
(
session.query(DagModel)
.filter(DagModel.fileloc == filename)
- .update({'has_import_errors': True}, synchronize_session='fetch')
+ .update({"has_import_errors": True}, synchronize_session="fetch")
)
session.commit()
+ @provide_session
+ def _validate_task_pools(self, *, dagbag: DagBag, session: Session = NEW_SESSION):
+ """
+ Validate task pools and record a DAG warning for any DAG whose tasks use a non-existent pool.
+ :meta private:
+ """
+ from airflow.models.pool import Pool
+
+ def check_pools(dag):
+ task_pools = {task.pool for task in dag.tasks}
+ nonexistent_pools = task_pools - pools
+ if nonexistent_pools:
+ return (
+ f"Dag '{dag.dag_id}' references non-existent pools: {list(sorted(nonexistent_pools))!r}"
+ )
+
+ pools = {p.pool for p in Pool.get_pools(session)}
+ for dag in dagbag.dags.values():
+ message = check_pools(dag)
+ if message:
+ self.dag_warnings.add(DagWarning(dag.dag_id, DagWarningType.NONEXISTENT_POOL, message))
+ for subdag in dag.subdags:
+ message = check_pools(subdag)
+ if message:
+ self.dag_warnings.add(DagWarning(subdag.dag_id, DagWarningType.NONEXISTENT_POOL, message))
+
+ def update_dag_warnings(self, *, session: Session, dagbag: DagBag) -> None:
+ """
+ For the DAGs in the given DagBag, record any associated configuration warnings and clear
+ warnings for files that no longer have them. These are usually displayed through the
+ Airflow UI so that users know that there are issues parsing DAGs.
+
+ :param session: session for ORM operations
+ :param dagbag: DagBag containing DAGs with configuration warnings
+ """
+ self._validate_task_pools(dagbag=dagbag)
+
+ stored_warnings = set(
+ session.query(DagWarning).filter(DagWarning.dag_id.in_(dagbag.dags.keys())).all()
+ )
+
+ for warning_to_delete in stored_warnings - self.dag_warnings:
+ session.delete(warning_to_delete)
+
+ for warning_to_add in self.dag_warnings:
+ session.merge(warning_to_add)
+
+ session.commit()
+
@provide_session
def execute_callbacks(
- self, dagbag: DagBag, callback_requests: List[CallbackRequest], session: Session = NEW_SESSION
+ self, dagbag: DagBag, callback_requests: list[CallbackRequest], session: Session = NEW_SESSION
) -> None:
"""
Execute on failure callbacks. These objects can come from SchedulerJob or from
@@ -575,7 +627,7 @@ def execute_callbacks(
self.log.debug("Processing Callback Request: %s", request)
try:
if isinstance(request, TaskCallbackRequest):
- self._execute_task_callbacks(dagbag, request)
+ self._execute_task_callbacks(dagbag, request, session=session)
elif isinstance(request, SlaCallbackRequest):
self.manage_slas(dagbag.get_dag(request.dag_id), session=session)
elif isinstance(request, DagCallbackRequest):
@@ -587,7 +639,27 @@ def execute_callbacks(
request.full_filepath,
)
- session.commit()
+ session.flush()
+
+ def execute_callbacks_without_dag(
+ self, callback_requests: list[CallbackRequest], session: Session
+ ) -> None:
+ """
+ Execute what callbacks we can as "best effort" when the dag cannot be found/had parse errors.
+
+ This is important so that tasks that failed when there is a parse
+ error don't get stuck in queued state.
+ """
+ for request in callback_requests:
+ self.log.debug("Processing Callback Request: %s", request)
+ if isinstance(request, TaskCallbackRequest):
+ self._execute_task_callbacks(None, request, session)
+ else:
+ self.log.info(
+ "Not executing %s callback for file %s as there was a dag parse error",
+ request.__class__.__name__,
+ request.full_filepath,
+ )
@provide_session
def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, session: Session):
@@ -597,27 +669,58 @@ def _execute_dag_callbacks(self, dagbag: DagBag, request: DagCallbackRequest, se
dagrun=dag_run, success=not request.is_failure_callback, reason=request.msg, session=session
)
- def _execute_task_callbacks(self, dagbag: DagBag, request: TaskCallbackRequest):
+ def _execute_task_callbacks(self, dagbag: DagBag | None, request: TaskCallbackRequest, session: Session):
+ if not request.is_failure_callback:
+ return
+
simple_ti = request.simple_task_instance
- if simple_ti.dag_id in dagbag.dags:
+ ti: TI | None = (
+ session.query(TI)
+ .filter_by(
+ dag_id=simple_ti.dag_id,
+ run_id=simple_ti.run_id,
+ task_id=simple_ti.task_id,
+ map_index=simple_ti.map_index,
+ )
+ .one_or_none()
+ )
+ if not ti:
+ return
+
+ task: Operator | None = None
+
+ if dagbag and simple_ti.dag_id in dagbag.dags:
dag = dagbag.dags[simple_ti.dag_id]
if simple_ti.task_id in dag.task_ids:
task = dag.get_task(simple_ti.task_id)
- if request.is_failure_callback:
- ti = TI(task, run_id=simple_ti.run_id, map_index=simple_ti.map_index)
- # TODO: Use simple_ti to improve performance here in the future
- ti.refresh_from_db()
- ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE)
- self.log.info('Executed failure callback for %s in state %s', ti, ti.state)
+ else:
+ # We don't have the _real_ dag here (perhaps it had a parse error?) but we still want to run
+ # `handle_failure` so that the state of the TI gets progressed.
+ #
+ # Since handle_failure _really_ wants a task, we do our best effort to give it one
+ from airflow.models.serialized_dag import SerializedDagModel
+
+ try:
+ model = session.query(SerializedDagModel).get(simple_ti.dag_id)
+ if model:
+ task = model.dag.get_task(simple_ti.task_id)
+ except (exc.NoResultFound, TaskNotFound):
+ pass
+ if task:
+ ti.refresh_from_task(task)
+
+ ti.handle_failure(error=request.msg, test_mode=self.UNIT_TEST_MODE, session=session)
+ self.log.info("Executed failure callback for %s in state %s", ti, ti.state)
+ session.flush()
@provide_session
def process_file(
self,
file_path: str,
- callback_requests: List[CallbackRequest],
+ callback_requests: list[CallbackRequest],
pickle_dags: bool = False,
- session: Session = None,
- ) -> Tuple[int, int]:
+ session: Session = NEW_SESSION,
+ ) -> tuple[int, int]:
"""
Process a Python file containing Airflow DAGs.
@@ -636,15 +739,14 @@ def process_file(
save them to the db
:param session: Sqlalchemy ORM Session
:return: number of dags found, count of import errors
- :rtype: Tuple[int, int]
"""
self.log.info("Processing file %s for tasks to queue", file_path)
try:
- dagbag = DagBag(file_path, include_examples=False, include_smart_sensor=False)
+ dagbag = DagBag(file_path, include_examples=False)
except Exception:
self.log.exception("Failed at reloading the DAG file %s", file_path)
- Stats.incr('dag_file_refresh_error', 1, 1)
+ Stats.incr("dag_file_refresh_error", 1, 1)
return 0, 0
if len(dagbag.dags) > 0:
@@ -652,17 +754,24 @@ def process_file(
else:
self.log.warning("No viable dags retrieved from %s", file_path)
self.update_import_errors(session, dagbag)
+ if callback_requests:
+ # If there were callback requests for this file but there was a
+ # parse error we still need to progress the state of TIs,
+ # otherwise they might be stuck in queued/running forever!
+ self.execute_callbacks_without_dag(callback_requests, session)
return 0, len(dagbag.import_errors)
- self.execute_callbacks(dagbag, callback_requests)
+ self.execute_callbacks(dagbag, callback_requests, session)
+ session.commit()
# Save individual DAGs in the ORM
- dagbag.sync_to_db()
+ dagbag.sync_to_db(processor_subdir=self._dag_directory, session=session)
+ session.commit()
if pickle_dags:
paused_dag_ids = DagModel.get_paused_dag_ids(dag_ids=dagbag.dag_ids)
- unpaused_dags: List[DAG] = [
+ unpaused_dags: list[DAG] = [
dag for dag_id, dag in dagbag.dags.items() if dag_id not in paused_dag_ids
]
@@ -675,4 +784,10 @@ def process_file(
except Exception:
self.log.exception("Error logging import errors!")
+ # Record DAG warnings in the metadatabase.
+ try:
+ self.update_dag_warnings(session=session, dagbag=dagbag)
+ except Exception:
+ self.log.exception("Error logging DAG warnings.")
+
return len(dagbag.dags), len(dagbag.import_errors)
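Review note on `_validate_task_pools`: the check reduces to a set difference between the pools named by a DAG's tasks and the pools that exist. Below is a minimal sketch of that logic, assuming plain strings instead of ORM objects. The function names `find_nonexistent_pools` and `pool_warning` are hypothetical and are not part of Airflow. The message format is copied from the diff.

```python
def find_nonexistent_pools(task_pools, existing_pools):
    """Return the pools referenced by tasks that are not defined, sorted for stable output."""
    return sorted(set(task_pools) - set(existing_pools))


def pool_warning(dag_id, task_pools, existing_pools):
    """Build the warning message for one DAG, or return None when every pool exists."""
    missing = find_nonexistent_pools(task_pools, existing_pools)
    if not missing:
        return None
    return f"Dag '{dag_id}' references non-existent pools: {missing!r}"


# Hypothetical inputs: one task uses a pool that was never created.
message = pool_warning("example_dag", ["default_pool", "gpu_pool"], {"default_pool"})
```

Sorting the missing pools keeps the message deterministic, so a stored warning with the same text is not rewritten on every parse.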
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
new file mode 100644
index 0000000000000..aa50fd16ecd9e
--- /dev/null
+++ b/airflow/datasets/__init__.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, ClassVar
+from urllib.parse import urlsplit
+
+import attr
+
+
+@attr.define()
+class Dataset:
+ """A Dataset is used for marking data dependencies between workflows."""
+
+ uri: str = attr.field(validator=[attr.validators.min_len(1), attr.validators.max_len(3000)])
+ extra: dict[str, Any] | None = None
+
+ version: ClassVar[int] = 1
+
+ @uri.validator
+ def _check_uri(self, attr, uri: str):
+ if uri.isspace():
+ raise ValueError(f"{attr.name} cannot be just whitespace")
+ try:
+ uri.encode("ascii")
+ except UnicodeEncodeError:
+ raise ValueError(f"{attr.name!r} must be ascii")
+ parsed = urlsplit(uri)
+ if parsed.scheme and parsed.scheme.lower() == "airflow":
+ raise ValueError(f"{attr.name!r} scheme `airflow` is reserved")
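Review note on the `Dataset` URI rules: the validators require a length between 1 and 3000 characters, reject whitespace-only and non-ASCII values, and reserve the `airflow` scheme. The sketch below restates those rules as a plain function using only the standard library, so it can be run without `attrs`. The function name `check_dataset_uri` and the error wording are assumptions made for illustration and differ from the class in the diff.

```python
from urllib.parse import urlsplit


def check_dataset_uri(uri: str) -> str:
    """Return the URI unchanged if it satisfies the rules, otherwise raise ValueError."""
    if not 1 <= len(uri) <= 3000:
        raise ValueError("uri must be between 1 and 3000 characters long")
    if uri.isspace():
        raise ValueError("uri cannot be just whitespace")
    try:
        uri.encode("ascii")
    except UnicodeEncodeError:
        raise ValueError("uri must be ascii") from None
    # urlsplit returns an empty scheme for values such as plain file names.
    if urlsplit(uri).scheme.lower() == "airflow":
        raise ValueError("uri scheme `airflow` is reserved")
    return uri


accepted = check_dataset_uri("s3://bucket/key")
```

Values such as an empty string, `"   "`, or `"airflow://anything"` raise `ValueError` in this sketch.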
diff --git a/airflow/datasets/manager.py b/airflow/datasets/manager.py
new file mode 100644
index 0000000000000..1e6723ea64ecb
--- /dev/null
+++ b/airflow/datasets/manager.py
@@ -0,0 +1,126 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from sqlalchemy import exc
+from sqlalchemy.orm.session import Session
+
+from airflow.configuration import conf
+from airflow.datasets import Dataset
+from airflow.models.dataset import DatasetDagRunQueue, DatasetEvent, DatasetModel
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance
+
+
+class DatasetManager(LoggingMixin):
+ """
+ A pluggable class that manages operations for datasets.
+
+ The intent is to have one place to handle all Dataset-related operations, so different
+ Airflow deployments can use plugins that broadcast dataset events to each other.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def register_dataset_change(
+ self, *, task_instance: TaskInstance, dataset: Dataset, extra=None, session: Session, **kwargs
+ ) -> None:
+ """
+ Register dataset related changes.
+
+ For local datasets, look them up, record the dataset event, queue dagruns, and broadcast
+ the dataset event.
+ """
+ dataset_model = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).one_or_none()
+ if not dataset_model:
+ self.log.warning("DatasetModel %s not found", dataset)
+ return
+ session.add(
+ DatasetEvent(
+ dataset_id=dataset_model.id,
+ source_task_id=task_instance.task_id,
+ source_dag_id=task_instance.dag_id,
+ source_run_id=task_instance.run_id,
+ source_map_index=task_instance.map_index,
+ extra=extra,
+ )
+ )
+ session.flush()
+ if dataset_model.consuming_dags:
+ self._queue_dagruns(dataset_model, session)
+ session.flush()
+
+ def _queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
+ # Possible race condition: if multiple dags or multiple (usually
+ # mapped) tasks update the same dataset, this can fail with a unique
+ # constraint violation.
+ #
+ # If we support it, use ON CONFLICT to do nothing, otherwise
+ # "fallback" to running this in a nested transaction. This is needed
+ # so that the adding of these rows happens in the same transaction
+ # where `ti.state` is changed.
+
+ if session.bind.dialect.name == "postgresql":
+ return self._postgres_queue_dagruns(dataset, session)
+ return self._slow_path_queue_dagruns(dataset, session)
+
+ def _slow_path_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
+ consuming_dag_ids = [x.dag_id for x in dataset.consuming_dags]
+ self.log.debug("consuming dag ids %s", consuming_dag_ids)
+
+ # Don't fail the whole transaction when a single RunQueue item conflicts.
+ # https://docs.sqlalchemy.org/en/14/orm/session_transaction.html#using-savepoint
+ for dag_id in consuming_dag_ids:
+ item = DatasetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset.id)
+ try:
+ with session.begin_nested():
+ session.merge(item)
+ except exc.IntegrityError:
+ self.log.debug("Skipping record %s", item, exc_info=True)
+
+ def _postgres_queue_dagruns(self, dataset: DatasetModel, session: Session) -> None:
+ from sqlalchemy.dialects.postgresql import insert
+
+ stmt = insert(DatasetDagRunQueue).values(dataset_id=dataset.id).on_conflict_do_nothing()
+ session.execute(
+ stmt,
+ [{"target_dag_id": target_dag.dag_id} for target_dag in dataset.consuming_dags],
+ )
+
+
+def resolve_dataset_manager() -> DatasetManager:
+ """Retrieve the dataset manager."""
+ _dataset_manager_class = conf.getimport(
+ section="core",
+ key="dataset_manager_class",
+ fallback="airflow.datasets.manager.DatasetManager",
+ )
+ _dataset_manager_kwargs = conf.getjson(
+ section="core",
+ key="dataset_manager_kwargs",
+ fallback={},
+ )
+ return _dataset_manager_class(**_dataset_manager_kwargs)
+
+
+dataset_manager = resolve_dataset_manager()
diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py
index 40d9921ec0169..af478314e18e4 100644
--- a/airflow/decorators/__init__.py
+++ b/airflow/decorators/__init__.py
@@ -14,13 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from typing import Any
from airflow.decorators.base import TaskDecorator
from airflow.decorators.branch_python import branch_task
+from airflow.decorators.external_python import external_python_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.sensor import sensor_task
+from airflow.decorators.short_circuit import short_circuit_task
from airflow.decorators.task_group import task_group
from airflow.models.dag import dag
from airflow.providers_manager import ProvidersManager
@@ -34,18 +38,24 @@
"task_group",
"python_task",
"virtualenv_task",
+ "external_python_task",
"branch_task",
+ "short_circuit_task",
+ "sensor_task",
]
class TaskDecoratorCollection:
"""Implementation to provide the ``@task`` syntax."""
- python: Any = staticmethod(python_task)
+ python = staticmethod(python_task)
virtualenv = staticmethod(virtualenv_task)
+ external_python = staticmethod(external_python_task)
branch = staticmethod(branch_task)
+ short_circuit = staticmethod(short_circuit_task)
+ sensor = staticmethod(sensor_task)
- __call__ = python # Alias '@task' to '@task.python'.
+ __call__: Any = python # Alias '@task' to '@task.python'.
def __getattr__(self, name: str) -> TaskDecorator:
"""Dynamically get provider-registered task decorators, e.g. ``@task.docker``."""
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index e970f61657c64..0a6d534b247fa 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -14,19 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
# This file provides better type hinting and editor autocompletion support for
# dynamically generated task decorators. Functions declared in this stub do not
# necessarily exist at run time. See "Creating Custom @task Decorators"
# documentation for more details.
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union, overload
+from datetime import timedelta
+from typing import Any, Callable, Iterable, Mapping, Union, overload
+
+from kubernetes.client import models as k8s
-from airflow.decorators.base import Function, Task, TaskDecorator
+from airflow.decorators.base import FParams, FReturn, Task, TaskDecorator
from airflow.decorators.branch_python import branch_task
+from airflow.decorators.external_python import external_python_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.sensor import sensor_task
from airflow.decorators.task_group import task_group
+from airflow.kubernetes.secret import Secret
from airflow.models.dag import dag
# Please keep this in sync with __init__.py's __all__.
@@ -38,7 +43,10 @@ __all__ = [
"task_group",
"python_task",
"virtualenv_task",
+ "external_python_task",
"branch_task",
+ "short_circuit_task",
+ "sensor_task",
]
class TaskDecoratorCollection:
@@ -46,10 +54,10 @@ class TaskDecoratorCollection:
def python(
self,
*,
- multiple_outputs: Optional[bool] = None,
+ multiple_outputs: bool | None = None,
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
# _PythonDecoratedOperator.
- templates_dict: Optional[Mapping[str, Any]] = None,
+ templates_dict: Mapping[str, Any] | None = None,
show_return_value_in_logs: bool = True,
**kwargs,
) -> TaskDecorator:
@@ -60,7 +68,7 @@ class TaskDecoratorCollection:
:param templates_dict: a dictionary where the values are templates that
will get templated by the Airflow engine sometime between
``__init__`` and ``execute`` takes place and are made available
- in your callable's context after the template has been applied
+ in your callable's context after the template has been applied.
:param show_return_value_in_logs: a bool value whether to show return_value
logs. Defaults to True, which allows return value log output.
It can be set to False to prevent log output of return value when you return huge data
@@ -68,33 +76,33 @@ class TaskDecoratorCollection:
"""
# [START mixin_for_typing]
@overload
- def python(self, python_callable: Function) -> Task[Function]: ...
+ def python(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [END mixin_for_typing]
@overload
def __call__(
self,
*,
- multiple_outputs: Optional[bool] = None,
- templates_dict: Optional[Mapping[str, Any]] = None,
+ multiple_outputs: bool | None = None,
+ templates_dict: Mapping[str, Any] | None = None,
show_return_value_in_logs: bool = True,
**kwargs,
) -> TaskDecorator:
"""Aliasing ``python``; signature should match exactly."""
@overload
- def __call__(self, python_callable: Function) -> Task[Function]:
+ def __call__(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]:
"""Aliasing ``python``; signature should match exactly."""
@overload
def virtualenv(
self,
*,
- multiple_outputs: Optional[bool] = None,
+ multiple_outputs: bool | None = None,
# 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
# _PythonVirtualenvDecoratedOperator.
requirements: Union[None, Iterable[str], str] = None,
python_version: Union[None, str, int, float] = None,
use_dill: bool = False,
system_site_packages: bool = True,
- templates_dict: Optional[Mapping[str, Any]] = None,
+ templates_dict: Mapping[str, Any] | None = None,
show_return_value_in_logs: bool = True,
**kwargs,
) -> TaskDecorator:
@@ -115,71 +123,115 @@ class TaskDecoratorCollection:
:param templates_dict: a dictionary where the values are templates that
will get templated by the Airflow engine sometime between
``__init__`` and ``execute`` takes place and are made available
- in your callable's context after the template has been applied
+ in your callable's context after the template has been applied.
:param show_return_value_in_logs: a bool value whether to show return_value
logs. Defaults to True, which allows return value log output.
It can be set to False to prevent log output of return value when you return huge data
such as transmission a large amount of XCom to TaskAPI.
"""
@overload
- def virtualenv(self, python_callable: Function) -> Task[Function]: ...
+ def virtualenv(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
+ def external_python(
+ self,
+ *,
+ python: str,
+ multiple_outputs: bool | None = None,
+ # 'python_callable', 'op_args' and 'op_kwargs' since they are filled by
+ # _PythonVirtualenvDecoratedOperator.
+ use_dill: bool = False,
+ templates_dict: Mapping[str, Any] | None = None,
+ show_return_value_in_logs: bool = True,
+ **kwargs,
+ ) -> TaskDecorator:
+ """Create a decorator to convert the decorated callable to a virtual environment task.
+
+ :param python: Full path string (file-system specific) that points to a Python binary inside
+ a virtualenv that should be used (in ``VENV/bin`` folder). Should be absolute path
+ (so usually start with "/" or "X:/" depending on the filesystem/os used).
+ :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
+ Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
+ :param use_dill: Whether to use dill to serialize
+ the args and result (pickle is default). This allows more complex types
+ but requires you to include dill in your requirements.
+ :param templates_dict: a dictionary where the values are templates that
+ will get templated by the Airflow engine sometime between
+ ``__init__`` and ``execute`` takes place and are made available
+ in your callable's context after the template has been applied.
+ :param show_return_value_in_logs: a bool value whether to show return_value
+ logs. Defaults to True, which allows return value log output.
+ It can be set to False to prevent log output of return value when you return huge data
+ such as when transmitting a large amount of XCom to TaskAPI.
+ """
@overload
- def branch(
- self, python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs
+ def branch(self, *, multiple_outputs: bool | None = None, **kwargs) -> TaskDecorator:
+ """Create a decorator to wrap the decorated callable into a BranchPythonOperator.
+
+ For more information on how to use this decorator, see :ref:`howto/operator:BranchPythonOperator`.
+ Accepts arbitrary kwargs for the operator. Can be reused in a single DAG.
+
+ :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
+ Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
+ """
+ @overload
+ def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
+ @overload
+ def short_circuit(
+ self,
+ *,
+ multiple_outputs: bool | None = None,
+ ignore_downstream_trigger_rules: bool = True,
+ **kwargs,
) -> TaskDecorator:
- """Wraps a python function into a BranchPythonOperator
+ """Create a decorator to wrap the decorated callable into a ShortCircuitOperator.
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:BranchPythonOperator`
- Accepts kwargs for operator kwarg. Can be reused in a single DAG.
- :param python_callable: Function to decorate
- :type python_callable: Optional[Callable]
- :param multiple_outputs: if set, function return value will be
- unrolled to multiple XCom values. Dict will unroll to xcom values with keys as XCom keys.
- Defaults to False.
- :type multiple_outputs: bool
+ :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
+ Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
+ :param ignore_downstream_trigger_rules: If set to True, all downstream tasks from this operator task
+ will be skipped. This is the default behavior. If set to False, the direct, downstream task(s)
+ will be skipped but the ``trigger_rule`` defined for all other downstream tasks will be respected.
+ Defaults to True.
"""
@overload
- def branch(self, python_callable: Function) -> Task[Function]: ...
+ def short_circuit(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
# [START decorator_signature]
def docker(
self,
*,
- multiple_outputs: Optional[bool] = None,
+ multiple_outputs: bool | None = None,
use_dill: bool = False, # Added by _DockerDecoratedOperator.
python_command: str = "python3",
# 'command', 'retrieve_output', and 'retrieve_output_path' are filled by
# _DockerDecoratedOperator.
image: str,
- api_version: Optional[str] = None,
- container_name: Optional[str] = None,
+ api_version: str | None = None,
+ container_name: str | None = None,
cpus: float = 1.0,
docker_url: str = "unix://var/run/docker.sock",
- environment: Optional[Dict[str, str]] = None,
- private_environment: Optional[Dict[str, str]] = None,
+ environment: dict[str, str] | None = None,
+ private_environment: dict[str, str] | None = None,
force_pull: bool = False,
- mem_limit: Optional[Union[float, str]] = None,
- host_tmp_dir: Optional[str] = None,
- network_mode: Optional[str] = None,
- tls_ca_cert: Optional[str] = None,
- tls_client_cert: Optional[str] = None,
- tls_client_key: Optional[str] = None,
- tls_hostname: Optional[Union[str, bool]] = None,
- tls_ssl_version: Optional[str] = None,
+ mem_limit: float | str | None = None,
+ host_tmp_dir: str | None = None,
+ network_mode: str | None = None,
+ tls_ca_cert: str | None = None,
+ tls_client_cert: str | None = None,
+ tls_client_key: str | None = None,
+ tls_hostname: str | bool | None = None,
+ tls_ssl_version: str | None = None,
tmp_dir: str = "/tmp/airflow",
- user: Optional[Union[str, int]] = None,
- mounts: Optional[List[str]] = None,
- working_dir: Optional[str] = None,
+ user: str | int | None = None,
+ mounts: list[str] | None = None,
+ working_dir: str | None = None,
xcom_all: bool = False,
- docker_conn_id: Optional[str] = None,
- dns: Optional[List[str]] = None,
- dns_search: Optional[List[str]] = None,
+ docker_conn_id: str | None = None,
+ dns: list[str] | None = None,
+ dns_search: list[str] | None = None,
auto_remove: bool = False,
- shm_size: Optional[int] = None,
+ shm_size: int | None = None,
tty: bool = False,
privileged: bool = False,
- cap_add: Optional[Iterable[str]] = None,
- extra_hosts: Optional[Dict[str, str]] = None,
+ cap_add: Iterable[str] | None = None,
+ extra_hosts: dict[str, str] | None = None,
**kwargs,
) -> TaskDecorator:
"""Create a decorator to convert the decorated callable to a Docker task.
@@ -245,5 +297,157 @@ class TaskDecoratorCollection:
:param cap_add: Include container capabilities
"""
# [END decorator_signature]
+ def kubernetes(
+ self,
+ *,
+ image: str,
+ kubernetes_conn_id: str = ...,
+ namespace: str = "default",
+ name: str = ...,
+ random_name_suffix: bool = True,
+ ports: list[k8s.V1ContainerPort] | None = None,
+ volume_mounts: list[k8s.V1VolumeMount] | None = None,
+ volumes: list[k8s.V1Volume] | None = None,
+ env_vars: list[k8s.V1EnvVar] | None = None,
+ env_from: list[k8s.V1EnvFromSource] | None = None,
+ secrets: list[Secret] | None = None,
+ in_cluster: bool | None = None,
+ cluster_context: str | None = None,
+ labels: dict | None = None,
+ reattach_on_restart: bool = True,
+ startup_timeout_seconds: int = 120,
+ get_logs: bool = True,
+ image_pull_policy: str | None = None,
+ annotations: dict | None = None,
+ container_resources: k8s.V1ResourceRequirements | None = None,
+ affinity: k8s.V1Affinity | None = None,
+ config_file: str = ...,
+ node_selector: dict | None = None,
+ image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
+ service_account_name: str | None = None,
+ is_delete_operator_pod: bool = True,
+ hostnetwork: bool = False,
+ tolerations: list[k8s.V1Toleration] | None = None,
+ security_context: dict | None = None,
+ dnspolicy: str | None = None,
+ schedulername: str | None = None,
+ init_containers: list[k8s.V1Container] | None = None,
+ log_events_on_failure: bool = False,
+ do_xcom_push: bool = False,
+ pod_template_file: str | None = None,
+ priority_class_name: str | None = None,
+ pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None,
+ termination_grace_period: int | None = None,
+ configmaps: list[str] | None = None,
+ **kwargs,
+ ) -> TaskDecorator:
+ """Create a decorator to convert a callable to a Kubernetes Pod task.
+
+ :param kubernetes_conn_id: The Kubernetes cluster's
+ :ref:`connection ID `.
+ :param namespace: Namespace to run within Kubernetes. Defaults to *default*.
+ :param image: Docker image to launch. Defaults to *hub.docker.com*, but
+ a fully qualified URL will point to a custom repository. (templated)
+ :param name: Name of the pod to run. This will be used (plus a random
+ suffix if *random_name_suffix* is *True*) to generate a pod ID
+ (DNS-1123 subdomain, containing only ``[a-z0-9.-]``). Defaults to
+ ``k8s_airflow_pod_{RANDOM_UUID}``.
+ :param random_name_suffix: If *True*, will generate a random suffix.
+ :param ports: Ports for the launched pod.
+ :param volume_mounts: *volumeMounts* for the launched pod.
+ :param volumes: Volumes for the launched pod. Includes *ConfigMaps* and
+ *PersistentVolumes*.
+ :param env_vars: Environment variables initialized in the container.
+ (templated)
+ :param env_from: List of sources to populate environment variables in
+ the container.
+ :param secrets: Kubernetes secrets to inject in the container. They can
+ be exposed as environment variables or files in a volume.
+ :param in_cluster: Run kubernetes client with *in_cluster* configuration.
+ :param cluster_context: Context that points to the Kubernetes cluster.
+ Ignored when *in_cluster* is *True*. If *None*, current-context will
+ be used.
+ :param reattach_on_restart: If the worker dies while the pod is running,
+ reattach and monitor during the next try. If *False*, always create
+ a new pod for each try.
+ :param labels: Labels to apply to the pod. (templated)
+ :param startup_timeout_seconds: Timeout in seconds to start up the pod.
+ :param get_logs: Get the stdout of the container as logs of the tasks.
+ :param image_pull_policy: Specify a policy to cache or always pull an
+ image.
+ :param annotations: Non-identifying metadata you can attach to the pod.
+ Can be a large range of data, and can include characters that are
+ not permitted by labels.
+ :param container_resources: Resources for the launched pod.
+ :param affinity: Affinity scheduling rules for the launched pod.
+ :param config_file: The path to the Kubernetes config file. If not
+ specified, default value is ``~/.kube/config``. (templated)
+ :param node_selector: A dict containing a group of scheduling rules.
+ :param image_pull_secrets: Any image pull secrets to be given to the
+ pod. If more than one secret is required, provide a comma separated
+ list, e.g. ``secret_a,secret_b``.
+ :param service_account_name: Name of the service account.
+ :param is_delete_operator_pod: What to do when the pod reaches its final
+ state, or the execution is interrupted. If *True* (default), delete
+ the pod; otherwise leave the pod.
+ :param hostnetwork: If *True*, enable host networking on the pod.
+ :param tolerations: A list of Kubernetes tolerations.
+ :param security_context: Security options the pod should run with
+ (PodSecurityContext).
+ :param dnspolicy: DNS policy for the pod.
+ :param schedulername: Specify a scheduler name for the pod
+ :param init_containers: Init containers for the launched pod.
+ :param log_events_on_failure: Log the pod's events if a failure occurs.
+ :param do_xcom_push: If *True*, the content of
+ ``/airflow/xcom/return.json`` in the container will also be pushed
+ to an XCom when the container completes.
+ :param pod_template_file: Path to pod template file (templated)
+ :param priority_class_name: Priority class name for the launched pod.
+ :param pod_runtime_info_envs: A list of environment variables
+ to be set in the container.
+ :param termination_grace_period: Termination grace period if task killed
+ in UI, defaults to kubernetes default
+ :param configmaps: A list of names of config maps from which it collects
+ ConfigMaps to populate the environment variables with. The contents
+ of the target ConfigMap's Data field will represent the key-value
+ pairs as environment variables. Extends env_from.
+ """
+ @overload
+ def sensor(
+ self,
+ *,
+ poke_interval: float = ...,
+ timeout: float = ...,
+ soft_fail: bool = False,
+ mode: str = ...,
+ exponential_backoff: bool = False,
+ max_wait: timedelta | float | None = None,
+ **kwargs,
+ ) -> TaskDecorator:
+ """
+ Wraps a Python function into a sensor operator.
+
+ :param poke_interval: Time in seconds that the job should wait in
+ between each try
+ :param timeout: Time in seconds before the task times out and fails.
+ :param soft_fail: Set to true to mark the task as SKIPPED on failure
+ :param mode: How the sensor operates.
+ Options are: ``{ poke | reschedule }``, default is ``poke``.
+ When set to ``poke`` the sensor is taking up a worker slot for its
+ whole execution time and sleeps between pokes. Use this mode if the
+ expected runtime of the sensor is short or if a short poke interval
+ is required. Note that the sensor will hold onto a worker slot and
+ a pool slot for the duration of the sensor's runtime in this mode.
+ When set to ``reschedule`` the sensor task frees the worker slot when
+ the criteria is not yet met and it's rescheduled at a later time. Use
+ this mode if the time before the criteria is met is expected to be
+ quite long. The poke interval should be more than one minute to
+ prevent too much load on the scheduler.
+ :param exponential_backoff: allow progressively longer waits between
+ pokes by using exponential backoff algorithm
+ :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
+ """
+ @overload
+ def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...
task: TaskDecoratorCollection
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 1b5b5b760bf77..7da74e5514e42 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -14,23 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-import functools
import inspect
import re
+from itertools import chain
from typing import (
- TYPE_CHECKING,
Any,
Callable,
+ ClassVar,
Collection,
Dict,
Generic,
Iterator,
Mapping,
- Optional,
Sequence,
- Set,
- Type,
TypeVar,
cast,
overload,
@@ -38,8 +36,10 @@
import attr
import typing_extensions
+from sqlalchemy.orm import Session
-from airflow.compat.functools import cache, cached_property
+from airflow import Dataset
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
@@ -50,50 +50,74 @@
parse_retries,
)
from airflow.models.dag import DAG, DagContext
-from airflow.models.mappedoperator import (
- MappedOperator,
- ValidationSource,
- ensure_xcomarg_return_value,
- get_mappable_types,
- prevent_duplicates,
+from airflow.models.expandinput import (
+ EXPAND_INPUT_EMPTY,
+ DictOfListsExpandInput,
+ ExpandInput,
+ ListOfDictsExpandInput,
+ OperatorExpandArgument,
+ OperatorExpandKwargsArgument,
+ is_mappable,
)
+from airflow.models.mappedoperator import MappedOperator, ValidationSource, ensure_xcomarg_return_value
from airflow.models.pool import Pool
from airflow.models.xcom_arg import XComArg
-from airflow.typing_compat import Protocol
+from airflow.typing_compat import ParamSpec, Protocol
from airflow.utils import timezone
from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
+from airflow.utils.helpers import prevent_duplicates
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.types import NOTSET
-if TYPE_CHECKING:
- import jinja2 # Slow import.
- from sqlalchemy.orm import Session
- from airflow.models.mappedoperator import Mappable
+class ExpandableFactory(Protocol):
+ """Protocol providing inspection against wrapped function.
+ This is used in ``validate_expand_kwargs`` and implemented by function
+ decorators like ``@task`` and ``@task_group``.
-def validate_python_callable(python_callable: Any) -> None:
+ :meta private:
"""
- Validate that python callable can be wrapped by operator.
- Raises exception if invalid.
- :param python_callable: Python object to be validated
- :raises: TypeError, AirflowException
- """
- if not callable(python_callable):
- raise TypeError('`python_callable` param must be callable')
- if 'self' in inspect.signature(python_callable).parameters.keys():
- raise AirflowException('@task does not support methods')
+ function: Callable
+
+ @cached_property
+ def function_signature(self) -> inspect.Signature:
+ return inspect.signature(self.function)
+
+ @cached_property
+ def _mappable_function_argument_names(self) -> set[str]:
+ """Arguments that can be mapped against."""
+ return set(self.function_signature.parameters)
+
+ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None:
+ """Ensure that all arguments passed to operator-mapping functions are accounted for."""
+ parameters = self.function_signature.parameters
+ if any(v.kind == inspect.Parameter.VAR_KEYWORD for v in parameters.values()):
+ return
+ kwargs_left = kwargs.copy()
+ for arg_name in self._mappable_function_argument_names:
+ value = kwargs_left.pop(arg_name, NOTSET)
+ if func != "expand" or value is NOTSET or is_mappable(value):
+ continue
+ tname = type(value).__name__
+ raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}")
+ if len(kwargs_left) == 1:
+ raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}")
+ elif kwargs_left:
+ names = ", ".join(repr(n) for n in kwargs_left)
+ raise TypeError(f"{func}() got unexpected keyword arguments {names}")
def get_unique_task_id(
task_id: str,
- dag: Optional[DAG] = None,
- task_group: Optional[TaskGroup] = None,
+ dag: DAG | None = None,
+ task_group: TaskGroup | None = None,
) -> str:
"""
- Generate unique task id given a DAG (or if run in a DAG context)
- Ids are generated by appending a unique number to the end of
+ Generate unique task id given a DAG (or if run in a DAG context).
+
+ IDs are generated by appending a unique number to the end of
the original task id.
Example:
@@ -144,38 +168,52 @@ class DecoratedOperator(BaseOperator):
PythonOperator). This gives a user the option to upstream kwargs as needed.
"""
- template_fields: Sequence[str] = ('op_args', 'op_kwargs')
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
- shallow_copy_attrs: Sequence[str] = ('python_callable',)
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
def __init__(
self,
*,
python_callable: Callable,
task_id: str,
- op_args: Optional[Collection[Any]] = None,
- op_kwargs: Optional[Mapping[str, Any]] = None,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
multiple_outputs: bool = False,
- kwargs_to_upstream: Optional[Dict[str, Any]] = None,
+ kwargs_to_upstream: dict[str, Any] | None = None,
**kwargs,
) -> None:
- task_id = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
+ task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
self.python_callable = python_callable
kwargs_to_upstream = kwargs_to_upstream or {}
op_args = op_args or []
op_kwargs = op_kwargs or {}
- # Check that arguments can be binded
- inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+ # Check that arguments can be bound. There's a slight difference when
+ # we do validation for task-mapping: Since there's no guarantee we can
+ # receive enough arguments at parse time, we use bind_partial to simply
+ # check all the arguments we know are valid. Whether these are enough
+ # can only be known at execution time, when unmapping happens, and this
+ # is called without the _airflow_mapped_validation_only flag.
+ if kwargs.get("_airflow_mapped_validation_only"):
+ inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs)
+ else:
+ inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+
self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs)
def execute(self, context: Context):
+ # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators
+ # as well
+ for arg in chain(self.op_args, self.op_kwargs.values()):
+ if isinstance(arg, Dataset):
+ self.inlets.append(arg)
return_value = super().execute(context)
return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)
@@ -183,49 +221,61 @@ def _handle_output(self, return_value: Any, context: Context, xcom_push: Callabl
"""
Handles logic for whether a decorator needs to push a single return value or multiple return values.
+ It sets outlets if any datasets are found in the returned value(s).
+
:param return_value:
:param context:
:param xcom_push:
"""
+ if isinstance(return_value, Dataset):
+ self.outlets.append(return_value)
+ if isinstance(return_value, list):
+ for item in return_value:
+ if isinstance(item, Dataset):
+ self.outlets.append(item)
if not self.multiple_outputs:
return return_value
if isinstance(return_value, dict):
for key in return_value.keys():
if not isinstance(key, str):
raise AirflowException(
- 'Returned dictionary keys must be strings when using '
- f'multiple_outputs, found {key} ({type(key)}) instead'
+ "Returned dictionary keys must be strings when using "
+ f"multiple_outputs, found {key} ({type(key)}) instead"
)
for key, value in return_value.items():
+ if isinstance(value, Dataset):
+ self.outlets.append(value)
xcom_push(context, key, value)
else:
raise AirflowException(
- f'Returned output was type {type(return_value)} expected dictionary for multiple_outputs'
+ f"Returned output was type {type(return_value)} expected dictionary for multiple_outputs"
)
return return_value
def _hook_apply_defaults(self, *args, **kwargs):
- if 'python_callable' not in kwargs:
+ if "python_callable" not in kwargs:
return args, kwargs
- python_callable = kwargs['python_callable']
- default_args = kwargs.get('default_args') or {}
- op_kwargs = kwargs.get('op_kwargs') or {}
+ python_callable = kwargs["python_callable"]
+ default_args = kwargs.get("default_args") or {}
+ op_kwargs = kwargs.get("op_kwargs") or {}
f_sig = inspect.signature(python_callable)
for arg in f_sig.parameters:
if arg not in op_kwargs and arg in default_args:
op_kwargs[arg] = default_args[arg]
- kwargs['op_kwargs'] = op_kwargs
+ kwargs["op_kwargs"] = op_kwargs
return args, kwargs
-Function = TypeVar("Function", bound=Callable)
+FParams = ParamSpec("FParams")
+
+FReturn = TypeVar("FReturn")
OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator")
@attr.define(slots=False)
-class _TaskDecorator(Generic[Function, OperatorSubclass]):
+class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]):
"""
Helper class for providing dynamic task mapping to decorated functions.
@@ -234,18 +284,20 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
:meta private:
"""
- function: Function = attr.ib()
- operator_class: Type[OperatorSubclass]
+ function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
+ operator_class: type[OperatorSubclass]
multiple_outputs: bool = attr.ib()
- kwargs: Dict[str, Any] = attr.ib(factory=dict)
+ kwargs: dict[str, Any] = attr.ib(factory=dict)
decorator_name: str = attr.ib(repr=False, default="task")
+ _airflow_is_task_decorator: ClassVar[bool] = True
+
@multiple_outputs.default
def _infer_multiple_outputs(self):
try:
return_type = typing_extensions.get_type_hints(self.function).get("return", Any)
- except Exception: # Can't evaluate retrurn type.
+ except TypeError: # Can't evaluate return type.
return False
ttype = getattr(return_type, "__origin__", return_type)
return ttype == dict or ttype == Dict
@@ -253,9 +305,9 @@ def _infer_multiple_outputs(self):
def __attrs_post_init__(self):
if "self" in self.function_signature.parameters:
raise TypeError(f"@{self.decorator_name} does not support methods")
- self.kwargs.setdefault('task_id', self.function.__name__)
+ self.kwargs.setdefault("task_id", self.function.__name__)
- def __call__(self, *args, **kwargs) -> XComArg:
+ def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
op = self.operator_class(
python_callable=self.function,
op_args=args,
@@ -268,24 +320,10 @@ def __call__(self, *args, **kwargs) -> XComArg:
return XComArg(op)
@property
- def __wrapped__(self) -> Function:
+ def __wrapped__(self) -> Callable[FParams, FReturn]:
return self.function
- @cached_property
- def function_signature(self):
- return inspect.signature(self.function)
-
- @cached_property
- def _function_is_vararg(self):
- parameters = self.function_signature.parameters
- return any(v.kind == inspect.Parameter.VAR_KEYWORD for v in parameters.values())
-
- @cached_property
- def _mappable_function_argument_names(self) -> Set[str]:
- """Arguments that can be mapped against."""
- return set(self.function_signature.parameters)
-
- def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
+ def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]):
# Ensure that context variables are not shadowed.
context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
if len(context_keys_being_mapped) == 1:
@@ -295,35 +333,34 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
names = ", ".join(repr(n) for n in context_keys_being_mapped)
raise ValueError(f"cannot call {func}() on task context variables {names}")
- # Ensure that all arguments passed in are accounted for.
- if self._function_is_vararg:
- return
- kwargs_left = kwargs.copy()
- for arg_name in self._mappable_function_argument_names:
- value = kwargs_left.pop(arg_name, NOTSET)
- if func != "expand" or value is NOTSET or isinstance(value, get_mappable_types()):
- continue
- tname = type(value).__name__
- raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}")
- if len(kwargs_left) == 1:
- raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}")
- elif kwargs_left:
- names = ", ".join(repr(n) for n in kwargs_left)
- raise TypeError(f"{func}() got unexpected keyword arguments {names}")
+ super()._validate_arg_names(func, kwargs)
- def expand(self, **map_kwargs: "Mappable") -> XComArg:
+ def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
if not map_kwargs:
raise TypeError("no arguments to expand against")
-
self._validate_arg_names("expand", map_kwargs)
prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
- ensure_xcomarg_return_value(map_kwargs)
+ # Since the input is already checked at parse time, we can set strict
+ # to False to skip the checks on execution.
+ return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
+
+ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
+ if isinstance(kwargs, Sequence):
+ for item in kwargs:
+ if not isinstance(item, (XComArg, Mapping)):
+ raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+ elif not isinstance(kwargs, XComArg):
+ raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+ return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
+ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
+ ensure_xcomarg_return_value(expand_input.value)
task_kwargs = self.kwargs.copy()
dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag()
task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag)
- partial_kwargs, default_params = get_merged_defaults(
+ partial_kwargs, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=task_kwargs.pop("params", None),
@@ -332,7 +369,8 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
partial_kwargs.update(task_kwargs)
task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
- params = partial_kwargs.pop("params", None) or default_params
+ if task_group:
+ task_id = task_group.child_id(task_id)
# Logic here should be kept in sync with BaseOperatorMeta.partial().
if "task_concurrency" in partial_kwargs:
@@ -361,12 +399,18 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
# Mypy does not work well with a subclassed attrs class :(
_MappedOperator = cast(Any, DecoratedMappedOperator)
+
+ try:
+ operator_name = self.operator_class.custom_operator_name # type: ignore
+ except AttributeError:
+ operator_name = self.operator_class.__name__
+
operator = _MappedOperator(
operator_class=self.operator_class,
- mapped_kwargs={},
+ expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input.
partial_kwargs=partial_kwargs,
task_id=task_id,
- params=params,
+ params=partial_params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
@@ -377,42 +421,33 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
is_empty=False,
task_module=self.operator_class.__module__,
task_type=self.operator_class.__name__,
+ operator_name=operator_name,
dag=dag,
task_group=task_group,
start_date=start_date,
end_date=end_date,
multiple_outputs=self.multiple_outputs,
python_callable=self.function,
- mapped_op_kwargs=map_kwargs,
+ op_kwargs_expand_input=expand_input,
+ disallow_kwargs_override=strict,
# Different from classic operators, kwargs passed to a taskflow
# task's expand() contribute to the op_kwargs operator argument, not
# the operator arguments themselves, and should expand against it.
- expansion_kwargs_attr="mapped_op_kwargs",
+ expand_input_attr="op_kwargs_expand_input",
)
return XComArg(operator=operator)
- def partial(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
self._validate_arg_names("partial", kwargs)
+ old_kwargs = self.kwargs.get("op_kwargs", {})
+ prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial")
+ kwargs.update(old_kwargs)
+ return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs})
- op_kwargs = self.kwargs.get("op_kwargs", {})
- op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")
-
- return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": op_kwargs})
-
- def override(self, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+ def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]:
return attr.evolve(self, kwargs={**self.kwargs, **kwargs})
-def _merge_kwargs(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> Dict[str, Any]:
- duplicated_keys = set(kwargs1).intersection(kwargs2)
- if len(duplicated_keys) == 1:
- raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
- elif duplicated_keys:
- duplicated_keys_display = ", ".join(sorted(duplicated_keys))
- raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
- return {**kwargs1, **kwargs2}
-
-
@attr.define(kw_only=True, repr=False)
class DecoratedMappedOperator(MappedOperator):
"""MappedOperator implementation for @task-decorated task function."""
@@ -420,81 +455,41 @@ class DecoratedMappedOperator(MappedOperator):
multiple_outputs: bool
python_callable: Callable
- # We can't save these in mapped_kwargs because op_kwargs need to be present
+ # We can't save these in expand_input because op_kwargs need to be present
# in partial_kwargs, and MappedOperator prevents duplication.
- mapped_op_kwargs: Dict[str, "Mappable"]
+ op_kwargs_expand_input: ExpandInput
def __hash__(self):
return id(self)
- @classmethod
- @cache
- def get_serialized_fields(cls):
- # The magic super() doesn't work here, so we use the explicit form.
- # Not using super(..., cls) to work around pyupgrade bug.
- sup = super(DecoratedMappedOperator, DecoratedMappedOperator)
- return sup.get_serialized_fields() | {"mapped_op_kwargs"}
-
def __attrs_post_init__(self):
# The magic super() doesn't work here, so we use the explicit form.
# Not using super(..., self) to work around pyupgrade bug.
super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self)
- XComArg.apply_upstream_relationship(self, self.mapped_op_kwargs)
-
- def _get_unmap_kwargs(self) -> Dict[str, Any]:
- partial_kwargs = self.partial_kwargs.copy()
- op_kwargs = _merge_kwargs(
- partial_kwargs.pop("op_kwargs"),
- self.mapped_op_kwargs,
- fail_reason="mapping already partial",
- )
- self._combined_op_kwargs = op_kwargs
- return {
- "dag": self.dag,
- "task_group": self.task_group,
- "task_id": self.task_id,
- "op_kwargs": op_kwargs,
+ XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value)
+
+ def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ # We only use op_kwargs_expand_input so this must always be empty.
+ assert self.expand_input is EXPAND_INPUT_EMPTY
+ op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session)
+ return {"op_kwargs": op_kwargs}, resolved_oids
+
+ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
+ partial_op_kwargs = self.partial_kwargs["op_kwargs"]
+ mapped_op_kwargs = mapped_kwargs["op_kwargs"]
+
+ if strict:
+ prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial")
+
+ kwargs = {
"multiple_outputs": self.multiple_outputs,
"python_callable": self.python_callable,
- **partial_kwargs,
- **self.mapped_kwargs,
+ "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs},
}
+ return super()._get_unmap_kwargs(kwargs, strict=False)
- def _resolve_expansion_kwargs(
- self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: "Session"
- ) -> None:
- expansion_kwargs = self._get_expansion_kwargs()
-
- self._already_resolved_op_kwargs = set()
- for k, v in expansion_kwargs.items():
- if isinstance(v, XComArg):
- self._already_resolved_op_kwargs.add(k)
- v = v.resolve(context, session=session)
- v = self._expand_mapped_field(k, v, context, session=session)
- kwargs['op_kwargs'][k] = v
- template_fields.discard(k)
-
- def render_template(
- self,
- value: Any,
- context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- seen_oids: Optional[Set] = None,
- ) -> Any:
- if hasattr(self, '_combined_op_kwargs') and value is self._combined_op_kwargs:
- # Avoid rendering values that came out of resolved XComArgs
- return {
- k: v
- if k in self._already_resolved_op_kwargs
- else super(DecoratedMappedOperator, DecoratedMappedOperator).render_template(
- self, v, context, jinja_env=jinja_env, seen_oids=seen_oids
- )
- for k, v in value.items()
- }
- return super().render_template(value, context, jinja_env=jinja_env, seen_oids=seen_oids)
-
-
-class Task(Generic[Function]):
+
+class Task(Generic[FParams, FReturn]):
"""Declaration of a @task-decorated callable for type-checking.
An instance of this type inherits the call signature of the decorated
@@ -505,18 +500,24 @@ class Task(Generic[Function]):
This type is implemented by ``_TaskDecorator`` at runtime.
"""
- __call__: Function
+ __call__: Callable[FParams, XComArg]
- function: Function
+ function: Callable[FParams, FReturn]
@property
- def __wrapped__(self) -> Function:
+ def __wrapped__(self) -> Callable[FParams, FReturn]:
...
- def expand(self, **kwargs: "Mappable") -> XComArg:
+ def partial(self, **kwargs: Any) -> Task[FParams, FReturn]:
...
- def partial(self, **kwargs: Any) -> "Task[Function]":
+ def expand(self, **kwargs: OperatorExpandArgument) -> XComArg:
+ ...
+
+ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
+ ...
+
+ def override(self, **kwargs: Any) -> Task[FParams, FReturn]:
...
@@ -524,36 +525,46 @@ class TaskDecorator(Protocol):
"""Type declaration for ``task_decorator_factory`` return type."""
@overload
- def __call__(self, python_callable: Function) -> Task[Function]:
+ def __call__( # type: ignore[misc]
+ self,
+ python_callable: Callable[FParams, FReturn],
+ ) -> Task[FParams, FReturn]:
"""For the "bare decorator" ``@task`` case."""
@overload
def __call__(
self,
*,
- multiple_outputs: Optional[bool] = None,
+ multiple_outputs: bool | None = None,
**kwargs: Any,
- ) -> Callable[[Function], Task[Function]]:
+ ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]:
"""For the decorator factory ``@task()`` case."""
+ def override(self, **kwargs: Any) -> Task[FParams, FReturn]:
+ ...
+
def task_decorator_factory(
- python_callable: Optional[Callable] = None,
+ python_callable: Callable | None = None,
*,
- multiple_outputs: Optional[bool] = None,
- decorated_operator_class: Type[BaseOperator],
+ multiple_outputs: bool | None = None,
+ decorated_operator_class: type[BaseOperator],
**kwargs,
) -> TaskDecorator:
- """
- A factory that generates a wrapper that wraps a function into an Airflow operator.
- Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+ """Generate a wrapper that wraps a function into an Airflow operator.
- :param python_callable: Function to decorate
- :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
- multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
- :param decorated_operator_class: The operator that executes the logic needed to run the python function in
- the correct environment
+ Can be reused in a single DAG.
+
+ :param python_callable: Function to decorate.
+ :param multiple_outputs: If set to True, the decorated function's return
+ value will be unrolled to multiple XCom values. Dict will unroll to XCom
+ values with its keys as XCom keys. If set to False (default), only at
+ most one XCom value is pushed.
+ :param decorated_operator_class: The operator that executes the logic needed
+ to run the python function in the correct environment.
+ Other kwargs are directly forwarded to the underlying operator class when
+ it's instantiated.
"""
if multiple_outputs is None:
multiple_outputs = cast(bool, attr.NOTHING)
@@ -566,11 +577,14 @@ def task_decorator_factory(
)
return cast(TaskDecorator, decorator)
elif python_callable is not None:
- raise TypeError('No args allowed while using @task, use kwargs instead')
- decorator_factory = functools.partial(
- _TaskDecorator,
- multiple_outputs=multiple_outputs,
- operator_class=decorated_operator_class,
- kwargs=kwargs,
- )
+ raise TypeError("No args allowed while using @task, use kwargs instead")
+
+ def decorator_factory(python_callable):
+ return _TaskDecorator(
+ function=python_callable,
+ multiple_outputs=multiple_outputs,
+ operator_class=decorated_operator_class,
+ kwargs=kwargs,
+ )
+
return cast(TaskDecorator, decorator_factory)
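
Note on the ``task_decorator_factory`` change above: the function has to serve both the bare ``@task`` form and the ``@task(...)`` factory form, and the diff replaces ``functools.partial`` with an explicit closure. The following is a minimal, Airflow-free sketch of that dual-mode pattern. ``Wrapped`` and ``make_decorator`` are hypothetical names used only for illustration, not part of Airflow.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class Wrapped:
    """Hypothetical stand-in for _TaskDecorator: keeps the function and its options."""

    function: Callable
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)


def make_decorator(python_callable: Optional[Callable] = None, **kwargs):
    if callable(python_callable):
        # Bare form: "@make_decorator" receives the function directly.
        return Wrapped(function=python_callable, kwargs=kwargs)
    if python_callable is not None:
        # A positional non-callable argument is rejected, as in the diff above.
        raise TypeError("No args allowed, use kwargs instead")

    # Factory form: "@make_decorator(...)" returns the actual decorator.
    def decorator_factory(func: Callable) -> Wrapped:
        return Wrapped(function=func, kwargs=kwargs)

    return decorator_factory
```

One reason to prefer the closure over ``functools.partial`` is that the wrapped function is passed by keyword, so it cannot collide with a positional parameter of the wrapper class. This is an interpretation; the diff itself does not state a motivation.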
diff --git a/airflow/decorators/branch_python.py b/airflow/decorators/branch_python.py
index ac83132574402..b7e3a94826679 100644
--- a/airflow/decorators/branch_python.py
+++ b/airflow/decorators/branch_python.py
@@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import inspect
from textwrap import dedent
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Sequence
from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.operators.python import BranchPythonOperator
-from airflow.utils.python_virtualenv import remove_task_decorator
+from airflow.utils.decorators import remove_task_decorator
class _BranchPythonDecoratedOperator(DecoratedOperator, BranchPythonOperator):
@@ -38,12 +39,14 @@ class _BranchPythonDecoratedOperator(DecoratedOperator, BranchPythonOperator):
Defaults to False.
"""
- template_fields: Sequence[str] = ('op_args', 'op_kwargs')
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
- shallow_copy_attrs: Sequence[str] = ('python_callable',)
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ custom_operator_name: str = "@task.branch"
def __init__(
self,
@@ -63,14 +66,12 @@ def get_python_source(self):
return res
-T = TypeVar("T", bound=Callable)
-
-
def branch_task(
- python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs
+ python_callable: Callable | None = None, multiple_outputs: bool | None = None, **kwargs
) -> TaskDecorator:
"""
- Wraps a python function into a BranchPythonOperator
+ Wraps a python function into a BranchPythonOperator.
+
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BranchPythonOperator`
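
Note on ``custom_operator_name``: the ``base.py`` hunk earlier in this diff reads this attribute inside a ``try``/``except AttributeError`` and falls back to the class ``__name__``. A minimal sketch of that lookup, with hypothetical classes standing in for real operators:

```python
class PlainOperator:
    """Hypothetical operator class that does not define a display name."""


class BranchLikeOperator:
    """Hypothetical operator class that sets a display name, as the hunk above does."""

    custom_operator_name = "@task.branch"


def resolve_operator_name(operator_class: type) -> str:
    # Prefer the explicit display name; fall back to the class name.
    try:
        return operator_class.custom_operator_name
    except AttributeError:
        return operator_class.__name__
```

With these definitions, ``resolve_operator_name(BranchLikeOperator)`` yields the decorator-style name, while a class without the attribute reports its ordinary class name.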
diff --git a/airflow/decorators/external_python.py b/airflow/decorators/external_python.py
new file mode 100644
index 0000000000000..273bf95321d83
--- /dev/null
+++ b/airflow/decorators/external_python.py
@@ -0,0 +1,102 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from textwrap import dedent
+from typing import Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
+from airflow.operators.python import ExternalPythonOperator
+from airflow.utils.decorators import remove_task_decorator
+
+
+class _PythonExternalDecoratedOperator(DecoratedOperator, ExternalPythonOperator):
+ """
+ Wraps a Python callable and captures args/kwargs when called for execution.
+
+ :param python: Full path string (file-system specific) that points to a Python binary inside
+ a virtualenv that should be used (in ``VENV/bin`` folder). Should be absolute path
+ (so usually start with "/" or "X:/" depending on the filesystem/os used).
+ :param python_callable: A reference to an object that is callable
+ :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+ in your function (templated)
+ :param op_args: a list of positional arguments that will get unpacked when
+ calling your callable (templated)
+ :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
+ multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
+ """
+
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
+ template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
+
+ # since we won't mutate the arguments, we should just do the shallow copy
+ # there are some cases we can't deepcopy the objects (e.g protobuf).
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ custom_operator_name: str = "@task.external_python"
+
+ def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
+ kwargs_to_upstream = {
+ "python_callable": python_callable,
+ "op_args": op_args,
+ "op_kwargs": op_kwargs,
+ }
+ super().__init__(
+ kwargs_to_upstream=kwargs_to_upstream,
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ **kwargs,
+ )
+
+ def get_python_source(self):
+ raw_source = inspect.getsource(self.python_callable)
+ res = dedent(raw_source)
+ res = remove_task_decorator(res, "@task.external_python")
+ return res
+
+
+def external_python_task(
+ python: str | None = None,
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """Wraps a callable into an Airflow operator to run via a Python virtual environment.
+
+ Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+
+ This function is only used during type checking or auto-completion.
+
+ :meta private:
+
+ :param python: Full path string (file-system specific) that points to a Python binary inside
+ a virtualenv that should be used (in ``VENV/bin`` folder). Should be absolute path
+ (so usually start with "/" or "X:/" depending on the filesystem/os used).
+ :param python_callable: Function to decorate
+ :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
+ multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys.
+ Defaults to False.
+ """
+ return task_decorator_factory(
+ python=python,
+ python_callable=python_callable,
+ multiple_outputs=multiple_outputs,
+ decorated_operator_class=_PythonExternalDecoratedOperator,
+ **kwargs,
+ )
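
Note on ``get_python_source`` above: the callable's source is dedented and the decorator is removed before the code is handed to the external interpreter, presumably because ``task`` is not importable there. The sketch below illustrates the idea on a source string. ``strip_decorator_line`` is a simplified, hypothetical stand-in and not Airflow's ``remove_task_decorator``; it only handles a decorator that fits on a single line.

```python
from textwrap import dedent


def strip_decorator_line(source: str, decorator_name: str) -> str:
    """Hypothetical helper: drop every line that starts with the decorator name."""
    kept = [
        line
        for line in source.splitlines()
        if not line.strip().startswith(decorator_name)
    ]
    return "\n".join(kept) + "\n"


# Source as it might look when the function is defined inside an indented block.
raw_source = '''
    @task.external_python(python="/path/to/venv/bin/python")
    def compute(x):
        return x * 2
'''

# Dedent first, then remove the decorator line.
clean = strip_decorator_line(dedent(raw_source).lstrip("\n"), "@task.external_python")
```

After this step ``clean`` contains only the plain function definition, which can be written to a script and run by another Python binary.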
diff --git a/airflow/decorators/python.py b/airflow/decorators/python.py
index c7c1a629a4dfb..af88e534bdae5 100644
--- a/airflow/decorators/python.py
+++ b/airflow/decorators/python.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Sequence
from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.operators.python import PythonOperator
@@ -34,12 +35,14 @@ class _PythonDecoratedOperator(DecoratedOperator, PythonOperator):
multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
"""
- template_fields: Sequence[str] = ('op_args', 'op_kwargs')
- template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
+ template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs")
+ template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}
# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
- shallow_copy_attrs: Sequence[str] = ('python_callable',)
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ custom_operator_name: str = "@task"
def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
kwargs_to_upstream = {
@@ -56,12 +59,9 @@ def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
)
-T = TypeVar("T", bound=Callable)
-
-
def python_task(
- python_callable: Optional[Callable] = None,
- multiple_outputs: Optional[bool] = None,
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
**kwargs,
) -> TaskDecorator:
"""Wraps a function into an Airflow operator.
diff --git a/airflow/decorators/python_virtualenv.py b/airflow/decorators/python_virtualenv.py
index e8fd681b8f3af..53ca04b8e8f78 100644
--- a/airflow/decorators/python_virtualenv.py
+++ b/airflow/decorators/python_virtualenv.py
@@ -14,14 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import inspect
from textwrap import dedent
-from typing import Callable, Optional, Sequence, TypeVar
+from typing import Callable, Sequence
from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.operators.python import PythonVirtualenvOperator
-from airflow.utils.python_virtualenv import remove_task_decorator
+from airflow.utils.decorators import remove_task_decorator
class _PythonVirtualenvDecoratedOperator(DecoratedOperator, PythonVirtualenvOperator):
@@ -37,12 +38,14 @@ class _PythonVirtualenvDecoratedOperator(DecoratedOperator, PythonVirtualenvOper
multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
"""
- template_fields: Sequence[str] = ('op_args', 'op_kwargs')
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects (e.g protobuf).
- shallow_copy_attrs: Sequence[str] = ('python_callable',)
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ custom_operator_name: str = "@task.virtualenv"
def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
kwargs_to_upstream = {
@@ -65,12 +68,9 @@ def get_python_source(self):
return res
-T = TypeVar("T", bound=Callable)
-
-
def virtualenv_task(
- python_callable: Optional[Callable] = None,
- multiple_outputs: Optional[bool] = None,
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
**kwargs,
) -> TaskDecorator:
"""Wraps a callable into an Airflow operator to run via a Python virtual environment.
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
new file mode 100644
index 0000000000000..20339686201a9
--- /dev/null
+++ b/airflow/decorators/sensor.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from airflow.decorators.base import TaskDecorator, get_unique_task_id, task_decorator_factory
+from airflow.models.taskinstance import Context
+from airflow.sensors.base import PokeReturnValue
+from airflow.sensors.python import PythonSensor
+
+
+class DecoratedSensorOperator(PythonSensor):
+ """
+ Wraps a Python callable and captures args/kwargs when called for execution.
+ :param python_callable: A reference to an object that is callable
+ :param task_id: task Id
+ :param op_args: a list of positional arguments that will get unpacked when
+ calling your callable (templated)
+ :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+ in your function (templated)
+ :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
+ that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
+ PythonOperator). This gives a user the option to upstream kwargs as needed.
+ """
+
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
+ template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
+
+ # since we won't mutate the arguments, we should just do the shallow copy
+ # there are some cases we can't deepcopy the objects (e.g protobuf).
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ def __init__(
+ self,
+ *,
+ task_id: str,
+ **kwargs,
+ ) -> None:
+ kwargs.pop("multiple_outputs")
+ kwargs["task_id"] = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group"))
+ super().__init__(**kwargs)
+
+ def poke(self, context: Context) -> PokeReturnValue | bool:
+ return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+def sensor_task(python_callable: Callable | None = None, **kwargs) -> TaskDecorator:
+ """
+ Wraps a function into an Airflow operator.
+ Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+ :param python_callable: Function to decorate
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ multiple_outputs=False,
+ decorated_operator_class=DecoratedSensorOperator,
+ **kwargs,
+ )
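
Note on ``DecoratedSensorOperator.poke`` above: it returns whatever the decorated callable returns, either a plain ``bool`` or a ``PokeReturnValue``. The sketch below shows, outside Airflow, how a caller might consume such a result. ``PokeResult`` and ``run_sensor`` are hypothetical names, and the retry loop is an assumption about sensor behaviour in general, not a copy of Airflow's scheduling logic.

```python
import time
from dataclasses import dataclass
from typing import Any, Callable, Union


@dataclass
class PokeResult:
    """Hypothetical stand-in for a poke return value that can carry a payload."""

    is_done: bool
    xcom_value: Any = None


def run_sensor(
    poke: Callable[[], Union[PokeResult, bool]],
    poke_interval: float = 0.0,
    max_pokes: int = 10,
) -> Any:
    """Call poke() until it reports completion, then return its payload (if any)."""
    for _ in range(max_pokes):
        result = poke()
        if isinstance(result, PokeResult):
            if result.is_done:
                return result.xcom_value
        elif result:
            # A plain truthy value means "done" without a payload.
            return None
        time.sleep(poke_interval)
    raise TimeoutError("sensor condition was not met")
```

A callable that returns ``PokeResult(is_done=True, xcom_value=...)`` finishes the loop and hands back its payload, while one that returns ``False`` is simply retried until ``max_pokes`` is exhausted.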
diff --git a/airflow/decorators/short_circuit.py b/airflow/decorators/short_circuit.py
new file mode 100644
index 0000000000000..8422c4aa0844a
--- /dev/null
+++ b/airflow/decorators/short_circuit.py
@@ -0,0 +1,84 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
+from airflow.operators.python import ShortCircuitOperator
+
+
+class _ShortCircuitDecoratedOperator(DecoratedOperator, ShortCircuitOperator):
+ """
+ Wraps a Python callable and captures args/kwargs when called for execution.
+
+ :param python_callable: A reference to an object that is callable
+ :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+ in your function (templated)
+ :param op_args: a list of positional arguments that will get unpacked when
+ calling your callable (templated)
+ :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
+ multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
+ """
+
+ template_fields: Sequence[str] = ("op_args", "op_kwargs")
+ template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
+
+ # since we won't mutate the arguments, we should just do the shallow copy
+ # there are some cases we can't deepcopy the objects (e.g protobuf).
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ custom_operator_name: str = "@task.short_circuit"
+
+ def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
+ kwargs_to_upstream = {
+ "python_callable": python_callable,
+ "op_args": op_args,
+ "op_kwargs": op_kwargs,
+ }
+ super().__init__(
+ kwargs_to_upstream=kwargs_to_upstream,
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ **kwargs,
+ )
+
+
+def short_circuit_task(
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """Wraps a function into a ShortCircuitOperator.
+
+ Accepts kwargs for operator kwarg. Can be reused in a single DAG.
+
+ This function is only used during type checking or auto-completion.
+
+ :param python_callable: Function to decorate
+ :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
+ multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
+
+ :meta private:
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ multiple_outputs=multiple_outputs,
+ decorated_operator_class=_ShortCircuitDecoratedOperator,
+ **kwargs,
+ )
diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py
index 674474947036e..2aa714be3b6c5 100644
--- a/airflow/decorators/task_group.py
+++ b/airflow/decorators/task_group.py
@@ -15,56 +15,89 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""
-A TaskGroup is a collection of closely related tasks on the same DAG that should be grouped
+"""Implements the ``@task_group`` function decorator.
+
+When the decorated function is called, a task group will be created to represent
+a collection of closely related tasks on the same DAG that should be grouped
together when the DAG is displayed graphically.
"""
+
+from __future__ import annotations
+
import functools
-from inspect import signature
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar, Union, cast, overload
+import inspect
+import warnings
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, Mapping, Sequence, TypeVar, overload
import attr
-from airflow.utils.task_group import TaskGroup
+from airflow.decorators.base import ExpandableFactory
+from airflow.models.expandinput import (
+ DictOfListsExpandInput,
+ ListOfDictsExpandInput,
+ MappedArgument,
+ OperatorExpandArgument,
+ OperatorExpandKwargsArgument,
+)
+from airflow.models.taskmixin import DAGNode
+from airflow.models.xcom_arg import XComArg
+from airflow.typing_compat import ParamSpec
+from airflow.utils.helpers import prevent_duplicates
+from airflow.utils.task_group import MappedTaskGroup, TaskGroup
if TYPE_CHECKING:
from airflow.models.dag import DAG
- from airflow.models.mappedoperator import Mappable
-F = TypeVar("F", bound=Callable)
-R = TypeVar("R")
+FParams = ParamSpec("FParams")
+FReturn = TypeVar("FReturn", None, DAGNode)
-task_group_sig = signature(TaskGroup.__init__)
+task_group_sig = inspect.signature(TaskGroup.__init__)
-@attr.define
-class TaskGroupDecorator(Generic[R]):
- """:meta private:"""
+@attr.define()
+class _TaskGroupFactory(ExpandableFactory, Generic[FParams, FReturn]):
+ function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable())
+ tg_kwargs: dict[str, Any] = attr.ib(factory=dict) # Parameters forwarded to TaskGroup.
+ partial_kwargs: dict[str, Any] = attr.ib(factory=dict) # Parameters forwarded to 'function'.
- function: Callable[..., Optional[R]] = attr.ib(validator=attr.validators.is_callable())
- kwargs: Dict[str, Any] = attr.ib(factory=dict)
- """kwargs for the TaskGroup"""
+ _task_group_created: bool = attr.ib(False, init=False)
- @function.validator
- def _validate_function(self, _, f):
- if 'self' in signature(f).parameters:
- raise TypeError('@task_group does not support methods')
+ tg_class: ClassVar[type[TaskGroup]] = TaskGroup
- @kwargs.validator
+ @tg_kwargs.validator
def _validate(self, _, kwargs):
task_group_sig.bind_partial(**kwargs)
def __attrs_post_init__(self):
- self.kwargs.setdefault('group_id', self.function.__name__)
+ self.tg_kwargs.setdefault("group_id", self.function.__name__)
+
+ def __del__(self):
+ if self.partial_kwargs and not self._task_group_created:
+ try:
+ group_id = repr(self.tg_kwargs["group_id"])
+ except KeyError:
+ group_id = f"at {hex(id(self))}"
+ warnings.warn(f"Partial task group {group_id} was never mapped!")
+
+ def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> DAGNode:
+ """Instantiate the task group.
+
+ This uses the wrapped function to create a task group. Depending on the
+ return type of the wrapped function, this either returns the last task
+ in the group, or the group itself, to support task chaining.
+ """
+ return self._create_task_group(TaskGroup, *args, **kwargs)
+
+ def _create_task_group(self, tg_factory: Callable[..., TaskGroup], *args: Any, **kwargs: Any) -> DAGNode:
+ with tg_factory(add_suffix_on_collision=True, **self.tg_kwargs) as task_group:
+ if self.function.__doc__ and not task_group.tooltip:
+ task_group.tooltip = self.function.__doc__
- def _make_task_group(self, **kwargs) -> TaskGroup:
- return TaskGroup(**kwargs)
-
- def __call__(self, *args, **kwargs) -> Union[R, TaskGroup]:
- with self._make_task_group(add_suffix_on_collision=True, **self.kwargs) as task_group:
# Invoke function to run Tasks inside the TaskGroup
retval = self.function(*args, **kwargs)
+ self._task_group_created = True
+
# If the task-creating function returns a task, forward the return value
# so dependencies bind to it. This is equivalent to
# with TaskGroup(...) as tg:
@@ -80,28 +113,55 @@ def __call__(self, *args, **kwargs) -> Union[R, TaskGroup]:
# start >> tg >> end
return task_group
-
-class Group(Generic[F]):
- """Declaration of a @task_group-decorated callable for type-checking.
-
- An instance of this type inherits the call signature of the decorated
- function wrapped in it (not *exactly* since it actually turns the function
- into an XComArg-compatible, but there's no way to express that right now),
- and provides two additional methods for task-mapping.
-
- This type is implemented by ``TaskGroupDecorator`` at runtime.
- """
-
- __call__: F
-
- function: F
-
- # Return value should match F's return type, but that's impossible to declare.
- def expand(self, **kwargs: "Mappable") -> Any:
- ...
-
- def partial(self, **kwargs: Any) -> "Group[F]":
- ...
+ def override(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
+ return attr.evolve(self, tg_kwargs={**self.tg_kwargs, **kwargs})
+
+ def partial(self, **kwargs: Any) -> _TaskGroupFactory[FParams, FReturn]:
+ self._validate_arg_names("partial", kwargs)
+ prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="duplicate partial")
+ kwargs.update(self.partial_kwargs)
+ return attr.evolve(self, partial_kwargs=kwargs)
+
+ def expand(self, **kwargs: OperatorExpandArgument) -> DAGNode:
+ if not kwargs:
+ raise TypeError("no arguments to expand against")
+ self._validate_arg_names("expand", kwargs)
+ prevent_duplicates(self.partial_kwargs, kwargs, fail_reason="mapping already partial")
+ expand_input = DictOfListsExpandInput(kwargs)
+ return self._create_task_group(
+ functools.partial(MappedTaskGroup, expand_input=expand_input),
+ **self.partial_kwargs,
+ **{k: MappedArgument(input=expand_input, key=k) for k in kwargs},
+ )
+
+ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument) -> DAGNode:
+ if isinstance(kwargs, Sequence):
+ for item in kwargs:
+ if not isinstance(item, (XComArg, Mapping)):
+ raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+ elif not isinstance(kwargs, XComArg):
+ raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+
+ # It's impossible to build a dict of stubs as keyword arguments if the
+ # function uses * or ** wildcard arguments.
+ function_has_vararg = any(
+ v.kind == inspect.Parameter.VAR_POSITIONAL or v.kind == inspect.Parameter.VAR_KEYWORD
+ for v in self.function_signature.parameters.values()
+ )
+ if function_has_vararg:
+ raise TypeError("calling expand_kwargs() on task group function with * or ** is not supported")
+
+ # We can't be sure how each argument is used in the function (well
+ # technically we can with AST but let's not), so we have to create stubs
+ # for every argument, including those with default values.
+ map_kwargs = (k for k in self.function_signature.parameters if k not in self.partial_kwargs)
+
+ expand_input = ListOfDictsExpandInput(kwargs)
+ return self._create_task_group(
+ functools.partial(MappedTaskGroup, expand_input=expand_input),
+ **self.partial_kwargs,
+ **{k: MappedArgument(input=expand_input, key=k) for k in map_kwargs},
+ )
# This covers the @task_group() case. Annotations are copied from the TaskGroup
@@ -113,28 +173,27 @@ def partial(self, **kwargs: Any) -> "Group[F]":
# disastrous if they go out of sync with TaskGroup.
@overload
def task_group(
- group_id: Optional[str] = None,
+ group_id: str | None = None,
prefix_group_id: bool = True,
- parent_group: Optional[TaskGroup] = None,
- dag: Optional["DAG"] = None,
- default_args: Optional[Dict[str, Any]] = None,
+ parent_group: TaskGroup | None = None,
+ dag: DAG | None = None,
+ default_args: dict[str, Any] | None = None,
tooltip: str = "",
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
-) -> Callable[[F], Group[F]]:
+) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]:
...
# This covers the @task_group case (no parentheses).
@overload
-def task_group(python_callable: F) -> Group[F]:
+def task_group(python_callable: Callable[FParams, FReturn]) -> _TaskGroupFactory[FParams, FReturn]:
...
def task_group(python_callable=None, **tg_kwargs):
- """
- Python TaskGroup decorator.
+ """Python TaskGroup decorator.
This wraps a function into an Airflow TaskGroup. When used as the
``@task_group()`` form, all arguments are forwarded to the underlying
@@ -143,6 +202,6 @@ def task_group(python_callable=None, **tg_kwargs):
:param python_callable: Function to decorate.
:param tg_kwargs: Keyword arguments for the TaskGroup object.
"""
- if callable(python_callable):
- return TaskGroupDecorator(function=python_callable, kwargs=tg_kwargs)
- return cast(Callable[[F], F], functools.partial(TaskGroupDecorator, kwargs=tg_kwargs))
+ if callable(python_callable) and not tg_kwargs:
+ return _TaskGroupFactory(function=python_callable, tg_kwargs=tg_kwargs)
+ return functools.partial(_TaskGroupFactory, tg_kwargs=tg_kwargs)
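Note on ``expand_kwargs()`` in the task_group.py hunk above: the wildcard check and the selection of arguments that need stubs can be tried in isolation. The following is a minimal sketch using only the standard library. ``has_wildcard_params``, ``stub_keys``, ``fixed`` and ``variadic`` are hypothetical names introduced for illustration and are not part of Airflow; the sketch mirrors only the signature inspection, not the creation of a ``MappedTaskGroup``.

```python
from __future__ import annotations

import inspect
from typing import Any, Callable


def has_wildcard_params(func: Callable[..., Any]) -> bool:
    """Return True if ``func`` declares *args or **kwargs."""
    wildcard_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    return any(p.kind in wildcard_kinds for p in inspect.signature(func).parameters.values())


def stub_keys(func: Callable[..., Any], partial_kwargs: dict[str, Any]) -> list[str]:
    """List every parameter that is not already fixed through partial kwargs.

    Parameters with default values are included too, because we cannot tell
    from the signature alone how the function body uses them.
    """
    if has_wildcard_params(func):
        raise TypeError("functions with * or ** parameters are not supported")
    return [name for name in inspect.signature(func).parameters if name not in partial_kwargs]


def fixed(a, b, c=1):  # hypothetical task-group function
    return None


def variadic(a, *args, **kwargs):  # hypothetical function using wildcards
    return None


if __name__ == "__main__":
    print(stub_keys(fixed, {"b": 2}))
    print(has_wildcard_params(variadic))
```

With a fixed signature, each remaining parameter name can be paired with a placeholder; with ``*args`` or ``**kwargs`` the set of names is open-ended, which is why the patch rejects that case.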
diff --git a/airflow/example_dags/example_bash_operator.py b/airflow/example_dags/example_bash_operator.py
index 335f3ad961d3b..2f0343e1584d4 100644
--- a/airflow/example_dags/example_bash_operator.py
+++ b/airflow/example_dags/example_bash_operator.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the BashOperator."""
+from __future__ import annotations
import datetime
@@ -27,22 +27,22 @@
from airflow.operators.empty import EmptyOperator
with DAG(
- dag_id='example_bash_operator',
- schedule_interval='0 0 * * *',
+ dag_id="example_bash_operator",
+ schedule="0 0 * * *",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
dagrun_timeout=datetime.timedelta(minutes=60),
- tags=['example', 'example2'],
+ tags=["example", "example2"],
params={"example_key": "example_value"},
) as dag:
run_this_last = EmptyOperator(
- task_id='run_this_last',
+ task_id="run_this_last",
)
# [START howto_operator_bash]
run_this = BashOperator(
- task_id='run_after_loop',
- bash_command='echo 1',
+ task_id="run_after_loop",
+ bash_command="echo 1",
)
# [END howto_operator_bash]
@@ -50,22 +50,22 @@
for i in range(3):
task = BashOperator(
- task_id='runme_' + str(i),
+ task_id="runme_" + str(i),
bash_command='echo "{{ task_instance_key_str }}" && sleep 1',
)
task >> run_this
# [START howto_operator_bash_template]
also_run_this = BashOperator(
- task_id='also_run_this',
- bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"',
+ task_id="also_run_this",
+ bash_command='echo "ti_key={{ task_instance_key_str }}"',
)
# [END howto_operator_bash_template]
also_run_this >> run_this_last
# [START howto_operator_bash_skip]
this_will_skip = BashOperator(
- task_id='this_will_skip',
+ task_id="this_will_skip",
bash_command='echo "hello world"; exit 99;',
dag=dag,
)
@@ -73,4 +73,4 @@
this_will_skip >> run_this_last
if __name__ == "__main__":
- dag.cli()
+ dag.test()
diff --git a/airflow/example_dags/example_branch_datetime_operator.py b/airflow/example_dags/example_branch_datetime_operator.py
index 3c86e40402aef..8a4d579b4c03a 100644
--- a/airflow/example_dags/example_branch_datetime_operator.py
+++ b/airflow/example_dags/example_branch_datetime_operator.py
@@ -15,64 +15,90 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating the usage of DateTimeBranchOperator with datetime as well as time objects as
targets.
"""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
from airflow.operators.datetime import BranchDateTimeOperator
from airflow.operators.empty import EmptyOperator
-dag = DAG(
+dag1 = DAG(
dag_id="example_branch_datetime_operator",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=["example"],
- schedule_interval="@daily",
+ schedule="@daily",
)
# [START howto_branch_datetime_operator]
-empty_task_1 = EmptyOperator(task_id='date_in_range', dag=dag)
-empty_task_2 = EmptyOperator(task_id='date_outside_range', dag=dag)
+empty_task_11 = EmptyOperator(task_id="date_in_range", dag=dag1)
+empty_task_21 = EmptyOperator(task_id="date_outside_range", dag=dag1)
cond1 = BranchDateTimeOperator(
- task_id='datetime_branch',
- follow_task_ids_if_true=['date_in_range'],
- follow_task_ids_if_false=['date_outside_range'],
+ task_id="datetime_branch",
+ follow_task_ids_if_true=["date_in_range"],
+ follow_task_ids_if_false=["date_outside_range"],
target_upper=pendulum.datetime(2020, 10, 10, 15, 0, 0),
target_lower=pendulum.datetime(2020, 10, 10, 14, 0, 0),
- dag=dag,
+ dag=dag1,
)
-# Run empty_task_1 if cond1 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00
-cond1 >> [empty_task_1, empty_task_2]
+# Run empty_task_11 if cond1 executes between 2020-10-10 14:00:00 and 2020-10-10 15:00:00
+cond1 >> [empty_task_11, empty_task_21]
# [END howto_branch_datetime_operator]
-dag = DAG(
+dag2 = DAG(
dag_id="example_branch_datetime_operator_2",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=["example"],
- schedule_interval="@daily",
+ schedule="@daily",
)
# [START howto_branch_datetime_operator_next_day]
-empty_task_1 = EmptyOperator(task_id='date_in_range', dag=dag)
-empty_task_2 = EmptyOperator(task_id='date_outside_range', dag=dag)
+empty_task_12 = EmptyOperator(task_id="date_in_range", dag=dag2)
+empty_task_22 = EmptyOperator(task_id="date_outside_range", dag=dag2)
cond2 = BranchDateTimeOperator(
- task_id='datetime_branch',
- follow_task_ids_if_true=['date_in_range'],
- follow_task_ids_if_false=['date_outside_range'],
+ task_id="datetime_branch",
+ follow_task_ids_if_true=["date_in_range"],
+ follow_task_ids_if_false=["date_outside_range"],
target_upper=pendulum.time(0, 0, 0),
target_lower=pendulum.time(15, 0, 0),
- dag=dag,
+ dag=dag2,
)
# Since target_lower happens after target_upper, target_upper will be moved to the following day
-# Run empty_task_1 if cond2 executes between 15:00:00, and 00:00:00 of the following day
-cond2 >> [empty_task_1, empty_task_2]
+# Run empty_task_12 if cond2 executes between 15:00:00, and 00:00:00 of the following day
+cond2 >> [empty_task_12, empty_task_22]
# [END howto_branch_datetime_operator_next_day]
+
+dag3 = DAG(
+ dag_id="example_branch_datetime_operator_3",
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["example"],
+ schedule="@daily",
+)
+# [START howto_branch_datetime_operator_logical_date]
+empty_task_13 = EmptyOperator(task_id="date_in_range", dag=dag3)
+empty_task_23 = EmptyOperator(task_id="date_outside_range", dag=dag3)
+
+cond3 = BranchDateTimeOperator(
+ task_id="datetime_branch",
+ use_task_logical_date=True,
+ follow_task_ids_if_true=["date_in_range"],
+ follow_task_ids_if_false=["date_outside_range"],
+ target_upper=pendulum.datetime(2020, 10, 10, 15, 0, 0),
+ target_lower=pendulum.datetime(2020, 10, 10, 14, 0, 0),
+ dag=dag3,
+)
+
+# Run empty_task_13 if the logical date of cond3 falls between 2020-10-10 14:00:00 and 2020-10-10 15:00:00
+cond3 >> [empty_task_13, empty_task_23]
+# [END howto_branch_datetime_operator_logical_date]
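Note on the ``howto_branch_datetime_operator_next_day`` example above: its comment says that when ``target_lower`` falls after ``target_upper``, the upper bound is moved to the following day. The sketch below illustrates that rule with plain ``datetime`` objects. ``in_time_range`` is a hypothetical helper, not the operator's implementation, and it assumes naive datetimes and inclusive bounds.

```python
from __future__ import annotations

import datetime


def in_time_range(now: datetime.datetime, lower: datetime.time, upper: datetime.time) -> bool:
    """Check lower <= now <= upper, rolling ``upper`` to the next day if it precedes ``lower``."""
    lower_dt = datetime.datetime.combine(now.date(), lower)
    upper_dt = datetime.datetime.combine(now.date(), upper)
    if upper_dt < lower_dt:
        # The window wraps past midnight, e.g. 15:00 until 00:00 of the next day.
        upper_dt += datetime.timedelta(days=1)
    return lower_dt <= now <= upper_dt


if __name__ == "__main__":
    window = (datetime.time(15, 0), datetime.time(0, 0))
    print(in_time_range(datetime.datetime(2020, 10, 10, 16, 0), *window))
    print(in_time_range(datetime.datetime(2020, 10, 10, 10, 0), *window))
```

Under these assumptions, a run at 16:00 falls inside the 15:00 to 00:00 window and a run at 10:00 falls outside it.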
diff --git a/airflow/example_dags/example_branch_day_of_week_operator.py b/airflow/example_dags/example_branch_day_of_week_operator.py
index 879824ab1c876..fb7dea56b07fe 100644
--- a/airflow/example_dags/example_branch_day_of_week_operator.py
+++ b/airflow/example_dags/example_branch_day_of_week_operator.py
@@ -15,26 +15,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating the usage of BranchDayOfWeekOperator.
"""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
from airflow.operators.empty import EmptyOperator
from airflow.operators.weekday import BranchDayOfWeekOperator
+from airflow.utils.weekday import WeekDay
with DAG(
dag_id="example_weekday_branch_operator",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=["example"],
- schedule_interval="@daily",
+ schedule="@daily",
) as dag:
# [START howto_operator_day_of_week_branch]
- empty_task_1 = EmptyOperator(task_id='branch_true')
- empty_task_2 = EmptyOperator(task_id='branch_false')
+ empty_task_1 = EmptyOperator(task_id="branch_true")
+ empty_task_2 = EmptyOperator(task_id="branch_false")
+ empty_task_3 = EmptyOperator(task_id="branch_weekend")
+ empty_task_4 = EmptyOperator(task_id="branch_mid_week")
branch = BranchDayOfWeekOperator(
task_id="make_choice",
@@ -42,7 +46,15 @@
follow_task_ids_if_false="branch_false",
week_day="Monday",
)
+ branch_weekend = BranchDayOfWeekOperator(
+ task_id="make_weekend_choice",
+ follow_task_ids_if_true="branch_weekend",
+ follow_task_ids_if_false="branch_mid_week",
+ week_day={WeekDay.SATURDAY, WeekDay.SUNDAY},
+ )
- # Run empty_task_1 if branch executes on Monday
+ # Run empty_task_1 if branch executes on Monday, empty_task_2 otherwise
branch >> [empty_task_1, empty_task_2]
+ # Run empty_task_3 if it's a weekend, empty_task_4 otherwise
+ empty_task_2 >> branch_weekend >> [empty_task_3, empty_task_4]
# [END howto_operator_day_of_week_branch]
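Note on the ``week_day`` argument in the example above: it is given once as a single name (``"Monday"``) and once as a set of ``WeekDay`` members. The sketch below shows one way such input could be normalized and matched against a date. ``Day``, ``normalize_days`` and ``matches`` are hypothetical stand-ins written for illustration; they are not the code in ``airflow.utils.weekday``, and the ISO numbering (Monday = 1) is an assumption of this sketch.

```python
from __future__ import annotations

import datetime
import enum
from typing import Iterable


class Day(enum.IntEnum):
    """Hypothetical stand-in for a weekday enum, numbered like ``date.isoweekday()``."""

    MONDAY = 1
    TUESDAY = 2
    WEDNESDAY = 3
    THURSDAY = 4
    FRIDAY = 5
    SATURDAY = 6
    SUNDAY = 7


def normalize_days(week_day: str | Day | Iterable[str | Day]) -> set[Day]:
    """Accept a single name, a single member, or an iterable of either."""
    if isinstance(week_day, (str, Day)):
        week_day = [week_day]
    return {d if isinstance(d, Day) else Day[d.upper()] for d in week_day}


def matches(date: datetime.date, week_day: str | Day | Iterable[str | Day]) -> bool:
    """Return True if ``date`` falls on one of the requested weekdays."""
    return Day(date.isoweekday()) in normalize_days(week_day)


if __name__ == "__main__":
    saturday = datetime.date(2021, 1, 2)
    print(matches(saturday, {Day.SATURDAY, Day.SUNDAY}))
    print(matches(saturday, "Monday"))
```

Normalizing to a set first keeps the membership test identical for both call styles.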
diff --git a/airflow/example_dags/example_branch_labels.py b/airflow/example_dags/example_branch_labels.py
index 72337eb45c873..220bb445cf153 100644
--- a/airflow/example_dags/example_branch_labels.py
+++ b/airflow/example_dags/example_branch_labels.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating the usage of labels with different branches.
"""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
@@ -27,7 +28,7 @@
with DAG(
"example_branch_labels",
- schedule_interval="@daily",
+ schedule="@daily",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
) as dag:
diff --git a/airflow/example_dags/example_branch_operator.py b/airflow/example_dags/example_branch_operator.py
index 8721c78bcbc8e..43bb34eeec68f 100644
--- a/airflow/example_dags/example_branch_operator.py
+++ b/airflow/example_dags/example_branch_operator.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the BranchPythonOperator."""
+from __future__ import annotations
import random
@@ -29,26 +29,26 @@
from airflow.utils.trigger_rule import TriggerRule
with DAG(
- dag_id='example_branch_operator',
+ dag_id="example_branch_operator",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval="@daily",
- tags=['example', 'example2'],
+ schedule="@daily",
+ tags=["example", "example2"],
) as dag:
run_this_first = EmptyOperator(
- task_id='run_this_first',
+ task_id="run_this_first",
)
- options = ['branch_a', 'branch_b', 'branch_c', 'branch_d']
+ options = ["branch_a", "branch_b", "branch_c", "branch_d"]
branching = BranchPythonOperator(
- task_id='branching',
+ task_id="branching",
python_callable=lambda: random.choice(options),
)
run_this_first >> branching
join = EmptyOperator(
- task_id='join',
+ task_id="join",
trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
)
@@ -58,7 +58,7 @@
)
empty_follow = EmptyOperator(
- task_id='follow_' + option,
+ task_id="follow_" + option,
)
# Label is optional here, but it can help identify more complex branches
diff --git a/airflow/example_dags/example_branch_operator_decorator.py b/airflow/example_dags/example_branch_operator_decorator.py
index 0ab4f76cafa2a..91101dc775aa8 100644
--- a/airflow/example_dags/example_branch_operator_decorator.py
+++ b/airflow/example_dags/example_branch_operator_decorator.py
@@ -15,11 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""Example DAG demonstrating the usage of the BranchPythonOperator."""
+"""Example DAG demonstrating the usage of the ``@task.branch`` TaskFlow API decorator."""
+from __future__ import annotations
import random
-from datetime import datetime
+
+import pendulum
from airflow import DAG
from airflow.decorators import task
@@ -28,39 +29,30 @@
from airflow.utils.trigger_rule import TriggerRule
with DAG(
- dag_id='example_branch_python_operator_decorator',
- start_date=datetime(2021, 1, 1),
+ dag_id="example_branch_python_operator_decorator",
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval="@daily",
- tags=['example', 'example2'],
+ schedule="@daily",
+ tags=["example", "example2"],
) as dag:
- run_this_first = EmptyOperator(
- task_id='run_this_first',
- )
+ run_this_first = EmptyOperator(task_id="run_this_first")
- options = ['branch_a', 'branch_b', 'branch_c', 'branch_d']
+ options = ["branch_a", "branch_b", "branch_c", "branch_d"]
@task.branch(task_id="branching")
- def random_choice():
- return random.choice(options)
+ def random_choice(choices: list[str]) -> str:
+ return random.choice(choices)
- random_choice_instance = random_choice()
+ random_choice_instance = random_choice(choices=options)
run_this_first >> random_choice_instance
- join = EmptyOperator(
- task_id='join',
- trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
- )
+ join = EmptyOperator(task_id="join", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
for option in options:
- t = EmptyOperator(
- task_id=option,
- )
+ t = EmptyOperator(task_id=option)
- empty_follow = EmptyOperator(
- task_id='follow_' + option,
- )
+ empty_follow = EmptyOperator(task_id="follow_" + option)
# Label is optional here, but it can help identify more complex branches
random_choice_instance >> Label(option) >> t >> empty_follow >> join
diff --git a/airflow/example_dags/example_branch_python_dop_operator_3.py b/airflow/example_dags/example_branch_python_dop_operator_3.py
index a8e0ce2c1c8af..f42abb46042be 100644
--- a/airflow/example_dags/example_branch_python_dop_operator_3.py
+++ b/airflow/example_dags/example_branch_python_dop_operator_3.py
@@ -15,48 +15,46 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
-Example DAG demonstrating the usage of BranchPythonOperator with depends_on_past=True, where tasks may be run
-or skipped on alternating runs.
+Example DAG demonstrating the usage of ``@task.branch`` TaskFlow API decorator with depends_on_past=True,
+where tasks may be run or skipped on alternating runs.
"""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
+from airflow.decorators import task
from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import BranchPythonOperator
-def should_run(**kwargs):
+@task.branch()
+def should_run(**kwargs) -> str:
"""
 Determine which empty_task should be run based on whether the minute of the execution date is even or odd.
:param dict kwargs: Context
:return: Id of the task to run
- :rtype: str
"""
print(
f"------------- exec dttm = {kwargs['execution_date']} and minute = {kwargs['execution_date'].minute}"
)
- if kwargs['execution_date'].minute % 2 == 0:
+ if kwargs["execution_date"].minute % 2 == 0:
return "empty_task_1"
else:
return "empty_task_2"
with DAG(
- dag_id='example_branch_dop_operator_v3',
- schedule_interval='*/1 * * * *',
+ dag_id="example_branch_dop_operator_v3",
+ schedule="*/1 * * * *",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- default_args={'depends_on_past': True},
- tags=['example'],
+ default_args={"depends_on_past": True},
+ tags=["example"],
) as dag:
- cond = BranchPythonOperator(
- task_id='condition',
- python_callable=should_run,
- )
+ cond = should_run()
- empty_task_1 = EmptyOperator(task_id='empty_task_1')
- empty_task_2 = EmptyOperator(task_id='empty_task_2')
+ empty_task_1 = EmptyOperator(task_id="empty_task_1")
+ empty_task_2 = EmptyOperator(task_id="empty_task_2")
cond >> [empty_task_1, empty_task_2]
diff --git a/airflow/example_dags/example_complex.py b/airflow/example_dags/example_complex.py
index 22e1906c042dd..f7c7e25aff62a 100644
--- a/airflow/example_dags/example_complex.py
+++ b/airflow/example_dags/example_complex.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example Airflow DAG that shows the complex DAG structure.
"""
+from __future__ import annotations
+
import pendulum
from airflow import models
@@ -27,10 +28,10 @@
with models.DAG(
dag_id="example_complex",
- schedule_interval=None,
+ schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example', 'example2', 'example3'],
+ tags=["example", "example2", "example3"],
) as dag:
# Create
create_entry_group = BashOperator(task_id="create_entry_group", bash_command="echo create_entry_group")
diff --git a/airflow/example_dags/example_dag_decorator.py b/airflow/example_dags/example_dag_decorator.py
index 88e0282016dd2..e8ee8a72997a2 100644
--- a/airflow/example_dags/example_dag_decorator.py
+++ b/airflow/example_dags/example_dag_decorator.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from __future__ import annotations
+
+from typing import Any
import httpx
import pendulum
@@ -39,33 +41,33 @@ def execute(self, context: Context):
# [START dag_decorator_usage]
@dag(
- schedule_interval=None,
+ schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example'],
+ tags=["example"],
)
-def example_dag_decorator(email: str = 'example@example.com'):
+def example_dag_decorator(email: str = "example@example.com"):
"""
DAG to send server IP to email.
:param email: Email to send IP to. Defaults to example@example.com.
"""
- get_ip = GetRequestOperator(task_id='get_ip', url="http://httpbin.org/get")
+ get_ip = GetRequestOperator(task_id="get_ip", url="http://httpbin.org/get")
@task(multiple_outputs=True)
- def prepare_email(raw_json: Dict[str, Any]) -> Dict[str, str]:
- external_ip = raw_json['origin']
+ def prepare_email(raw_json: dict[str, Any]) -> dict[str, str]:
+ external_ip = raw_json["origin"]
return {
- 'subject': f'Server connected from {external_ip}',
- 'body': f'Seems like today your server executing Airflow is connected from IP {external_ip}<br>',
+ "subject": f"Server connected from {external_ip}",
+ "body": f"Seems like today your server executing Airflow is connected from IP {external_ip}<br>",
}
email_info = prepare_email(get_ip.output)
EmailOperator(
- task_id='send_email', to=email, subject=email_info['subject'], html_content=email_info['body']
+ task_id="send_email", to=email, subject=email_info["subject"], html_content=email_info["body"]
)
-dag = example_dag_decorator()
+example_dag = example_dag_decorator()
# [END dag_decorator_usage]
diff --git a/airflow/example_dags/example_datasets.py b/airflow/example_dags/example_datasets.py
new file mode 100644
index 0000000000000..c2db9b1968614
--- /dev/null
+++ b/airflow/example_dags/example_datasets.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Example DAG for demonstrating behavior of Datasets feature.
+
+Notes on usage:
+
+Turn on all the dags.
+
+DAG dataset_produces_1 should run because it's on a schedule.
+
+After dataset_produces_1 runs, dataset_consumes_1 should be triggered immediately
+because its only dataset dependency is managed by dataset_produces_1.
+
+No other dags should be triggered. Note that even though dataset_consumes_1_and_2 depends on
+the dataset in dataset_produces_1, it will not be triggered until dataset_produces_2 runs
+(and dataset_produces_2 is left with no schedule so that we can trigger it manually).
+
+Next, trigger dataset_produces_2. After dataset_produces_2 finishes,
+dataset_consumes_1_and_2 should run.
+
+Dags dataset_consumes_1_never_scheduled and dataset_consumes_unknown_never_scheduled should not run because
+they depend on datasets that never get updated.
+"""
+from __future__ import annotations
+
+import pendulum
+
+from airflow import DAG, Dataset
+from airflow.operators.bash import BashOperator
+
+# [START dataset_def]
+dag1_dataset = Dataset("s3://dag1/output_1.txt", extra={"hi": "bye"})
+# [END dataset_def]
+dag2_dataset = Dataset("s3://dag2/output_1.txt", extra={"hi": "bye"})
+
+with DAG(
+ dag_id="dataset_produces_1",
+ catchup=False,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ schedule="@daily",
+ tags=["produces", "dataset-scheduled"],
+) as dag1:
+ # [START task_outlet]
+ BashOperator(outlets=[dag1_dataset], task_id="producing_task_1", bash_command="sleep 5")
+ # [END task_outlet]
+
+with DAG(
+ dag_id="dataset_produces_2",
+ catchup=False,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ schedule=None,
+ tags=["produces", "dataset-scheduled"],
+) as dag2:
+ BashOperator(outlets=[dag2_dataset], task_id="producing_task_2", bash_command="sleep 5")
+
+# [START dag_dep]
+with DAG(
+ dag_id="dataset_consumes_1",
+ catchup=False,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ schedule=[dag1_dataset],
+ tags=["consumes", "dataset-scheduled"],
+) as dag3:
+ # [END dag_dep]
+ BashOperator(
+ outlets=[Dataset("s3://consuming_1_task/dataset_other.txt")],
+ task_id="consuming_1",
+ bash_command="sleep 5",
+ )
+
+with DAG(
+ dag_id="dataset_consumes_1_and_2",
+ catchup=False,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ schedule=[dag1_dataset, dag2_dataset],
+ tags=["consumes", "dataset-scheduled"],
+) as dag4:
+ BashOperator(
+ outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
+ task_id="consuming_2",
+ bash_command="sleep 5",
+ )
+
+with DAG(
+ dag_id="dataset_consumes_1_never_scheduled",
+ catchup=False,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ schedule=[
+ dag1_dataset,
+ Dataset("s3://this-dataset-doesnt-get-triggered"),
+ ],
+ tags=["consumes", "dataset-scheduled"],
+) as dag5:
+ BashOperator(
+ outlets=[Dataset("s3://consuming_2_task/dataset_other_unknown.txt")],
+ task_id="consuming_3",
+ bash_command="sleep 5",
+ )
+
+with DAG(
+ dag_id="dataset_consumes_unknown_never_scheduled",
+ catchup=False,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ schedule=[
+ Dataset("s3://unrelated/dataset3.txt"),
+ Dataset("s3://unrelated/dataset_other_unknown.txt"),
+ ],
+ tags=["dataset-scheduled"],
+) as dag6:
+ BashOperator(
+ task_id="unrelated_task",
+ outlets=[Dataset("s3://unrelated_task/dataset_other_unknown.txt")],
+ bash_command="sleep 5",
+ )
diff --git a/airflow/example_dags/example_dynamic_task_mapping.py b/airflow/example_dags/example_dynamic_task_mapping.py
new file mode 100644
index 0000000000000..dce6cda20972c
--- /dev/null
+++ b/airflow/example_dags/example_dynamic_task_mapping.py
@@ -0,0 +1,38 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAG demonstrating the usage of dynamic task mapping."""
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.decorators import task
+
+with DAG(dag_id="example_dynamic_task_mapping", start_date=datetime(2022, 3, 4)) as dag:
+
+ @task
+ def add_one(x: int):
+ return x + 1
+
+ @task
+ def sum_it(values):
+ total = sum(values)
+ print(f"Total was {total}")
+
+ added_values = add_one.expand(x=[1, 2, 3])
+ sum_it(added_values)
diff --git a/airflow/example_dags/example_external_task_marker_dag.py b/airflow/example_dags/example_external_task_marker_dag.py
index 0c4479a0d66f0..1b4f5d3ffd8f2 100644
--- a/airflow/example_dags/example_external_task_marker_dag.py
+++ b/airflow/example_dags/example_external_task_marker_dag.py
@@ -15,27 +15,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating setting up inter-DAG dependencies using ExternalTaskSensor and
-ExternalTaskMarker
+ExternalTaskMarker.
In this example, child_task1 in example_external_task_marker_child depends on parent_task in
-example_external_task_marker_parent. When parent_task is cleared with "Recursive" selected,
-the presence of ExternalTaskMarker tells Airflow to clear child_task1 and its
-downstream tasks.
+example_external_task_marker_parent. When parent_task is cleared with 'Recursive' selected,
+the presence of ExternalTaskMarker tells Airflow to clear child_task1 and its downstream tasks.
ExternalTaskSensor will keep poking for the status of remote ExternalTaskMarker task at a regular
interval till one of the following will happen:
-1. ExternalTaskMarker reaches the states mentioned in the allowed_states list
- In this case, ExternalTaskSensor will exit with a success status code
-2. ExternalTaskMarker reaches the states mentioned in the failed_states list
- In this case, ExternalTaskSensor will raise an AirflowException and user need to handle this
- with multiple downstream tasks
-3. ExternalTaskSensor times out
- In this case, ExternalTaskSensor will raise AirflowSkipException or AirflowSensorTimeout
- exception
+
+ExternalTaskMarker reaches one of the states listed in allowed_states.
+In this case, ExternalTaskSensor will exit with a success status code.
+
+ExternalTaskMarker reaches one of the states listed in failed_states.
+In this case, ExternalTaskSensor will raise an AirflowException, and the user needs to handle
+this with multiple downstream tasks.
+
+ExternalTaskSensor times out. In this case, ExternalTaskSensor will raise an AirflowSkipException
+or AirflowSensorTimeout exception.
+
"""
+from __future__ import annotations
import pendulum
@@ -49,8 +51,8 @@
dag_id="example_external_task_marker_parent",
start_date=start_date,
catchup=False,
- schedule_interval=None,
- tags=['example2'],
+ schedule=None,
+ tags=["example2"],
) as parent_dag:
# [START howto_operator_external_task_marker]
parent_task = ExternalTaskMarker(
@@ -63,9 +65,9 @@
with DAG(
dag_id="example_external_task_marker_child",
start_date=start_date,
- schedule_interval=None,
+ schedule=None,
catchup=False,
- tags=['example2'],
+ tags=["example2"],
) as child_dag:
# [START howto_operator_external_task_sensor]
child_task1 = ExternalTaskSensor(
@@ -73,10 +75,23 @@
external_dag_id=parent_dag.dag_id,
external_task_id=parent_task.task_id,
timeout=600,
- allowed_states=['success'],
- failed_states=['failed', 'skipped'],
+ allowed_states=["success"],
+ failed_states=["failed", "skipped"],
mode="reschedule",
)
# [END howto_operator_external_task_sensor]
- child_task2 = EmptyOperator(task_id="child_task2")
- child_task1 >> child_task2
+
+ # [START howto_operator_external_task_sensor_with_task_group]
+ child_task2 = ExternalTaskSensor(
+ task_id="child_task2",
+ external_dag_id=parent_dag.dag_id,
+ external_task_group_id="parent_dag_task_group_id",
+ timeout=600,
+ allowed_states=["success"],
+ failed_states=["failed", "skipped"],
+ mode="reschedule",
+ )
+ # [END howto_operator_external_task_sensor_with_task_group]
+
+ child_task3 = EmptyOperator(task_id="child_task3")
+ child_task1 >> child_task2 >> child_task3
diff --git a/airflow/example_dags/example_kubernetes_executor.py b/airflow/example_dags/example_kubernetes_executor.py
index f278eb379b1e8..b9e6bdba35445 100644
--- a/airflow/example_dags/example_kubernetes_executor.py
+++ b/airflow/example_dags/example_kubernetes_executor.py
@@ -18,6 +18,8 @@
"""
This is an example dag for using a Kubernetes Executor Configuration.
"""
+from __future__ import annotations
+
import logging
import os
@@ -30,8 +32,8 @@
log = logging.getLogger(__name__)
-worker_container_repository = conf.get('kubernetes', 'worker_container_repository')
-worker_container_tag = conf.get('kubernetes', 'worker_container_tag')
+worker_container_repository = conf.get("kubernetes_executor", "worker_container_repository")
+worker_container_tag = conf.get("kubernetes_executor", "worker_container_tag")
try:
from kubernetes.client import models as k8s
@@ -45,11 +47,11 @@
if k8s:
with DAG(
- dag_id='example_kubernetes_executor',
- schedule_interval=None,
+ dag_id="example_kubernetes_executor",
+ schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example3'],
+ tags=["example3"],
) as dag:
# You can use annotations on your kubernetes pods!
start_task_executor_config = {
@@ -88,8 +90,8 @@ def test_volume_mount():
Tests whether the volume has been mounted.
"""
- with open('/foo/volume_mount_test.txt', 'w') as foo:
- foo.write('Hello')
+ with open("/foo/volume_mount_test.txt", "w") as foo:
+ foo.write("Hello")
return_code = os.system("cat /foo/volume_mount_test.txt")
if return_code != 0:
@@ -110,7 +112,7 @@ def test_volume_mount():
k8s.V1Container(
name="sidecar",
image="ubuntu",
- args=["echo \"retrieved from mount\" > /shared/test.txt"],
+ args=['echo "retrieved from mount" > /shared/test.txt'],
command=["bash", "-cx"],
volume_mounts=[k8s.V1VolumeMount(mount_path="/shared/", name="shared-empty-dir")],
),
@@ -152,7 +154,7 @@ def non_root_task():
executor_config_other_ns = {
"pod_override": k8s.V1Pod(
- metadata=k8s.V1ObjectMeta(namespace="test-namespace", labels={'release': 'stable'})
+ metadata=k8s.V1ObjectMeta(namespace="test-namespace", labels={"release": "stable"})
)
}
@@ -191,21 +193,21 @@ def base_image_override_task():
k8s.V1PodAffinityTerm(
label_selector=k8s.V1LabelSelector(
match_expressions=[
- k8s.V1LabelSelectorRequirement(key='app', operator='In', values=['airflow'])
+ k8s.V1LabelSelectorRequirement(key="app", operator="In", values=["airflow"])
]
),
- topology_key='kubernetes.io/hostname',
+ topology_key="kubernetes.io/hostname",
)
]
)
)
# Use k8s_client.V1Toleration to define node tolerations
- k8s_tolerations = [k8s.V1Toleration(key='dedicated', operator='Equal', value='airflow')]
+ k8s_tolerations = [k8s.V1Toleration(key="dedicated", operator="Equal", value="airflow")]
# Use k8s_client.V1ResourceRequirements to define resource limits
k8s_resource_requirements = k8s.V1ResourceRequirements(
- requests={'memory': '512Mi'}, limits={'memory': '512Mi'}
+ requests={"memory": "512Mi"}, limits={"memory": "512Mi"}
)
kube_exec_config_resource_limits = {
diff --git a/airflow/example_dags/example_latest_only.py b/airflow/example_dags/example_latest_only.py
index 92ec1436a6951..3dc4e9137e350 100644
--- a/airflow/example_dags/example_latest_only.py
+++ b/airflow/example_dags/example_latest_only.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example of the LatestOnlyOperator"""
+from __future__ import annotations
import datetime as dt
@@ -25,13 +25,13 @@
from airflow.operators.latest_only import LatestOnlyOperator
with DAG(
- dag_id='latest_only',
- schedule_interval=dt.timedelta(hours=4),
+ dag_id="latest_only",
+ schedule=dt.timedelta(hours=4),
start_date=dt.datetime(2021, 1, 1),
catchup=False,
- tags=['example2', 'example3'],
+ tags=["example2", "example3"],
) as dag:
- latest_only = LatestOnlyOperator(task_id='latest_only')
- task1 = EmptyOperator(task_id='task1')
+ latest_only = LatestOnlyOperator(task_id="latest_only")
+ task1 = EmptyOperator(task_id="task1")
latest_only >> task1
diff --git a/airflow/example_dags/example_latest_only_with_trigger.py b/airflow/example_dags/example_latest_only_with_trigger.py
index 56d6a24462217..e71d7e05c263c 100644
--- a/airflow/example_dags/example_latest_only_with_trigger.py
+++ b/airflow/example_dags/example_latest_only_with_trigger.py
@@ -18,6 +18,7 @@
"""
Example LatestOnlyOperator and TriggerRule interactions
"""
+from __future__ import annotations
# [START example]
import datetime
@@ -30,17 +31,17 @@
from airflow.utils.trigger_rule import TriggerRule
with DAG(
- dag_id='latest_only_with_trigger',
- schedule_interval=datetime.timedelta(hours=4),
+ dag_id="latest_only_with_trigger",
+ schedule=datetime.timedelta(hours=4),
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example3'],
+ tags=["example3"],
) as dag:
- latest_only = LatestOnlyOperator(task_id='latest_only')
- task1 = EmptyOperator(task_id='task1')
- task2 = EmptyOperator(task_id='task2')
- task3 = EmptyOperator(task_id='task3')
- task4 = EmptyOperator(task_id='task4', trigger_rule=TriggerRule.ALL_DONE)
+ latest_only = LatestOnlyOperator(task_id="latest_only")
+ task1 = EmptyOperator(task_id="task1")
+ task2 = EmptyOperator(task_id="task2")
+ task3 = EmptyOperator(task_id="task3")
+ task4 = EmptyOperator(task_id="task4", trigger_rule=TriggerRule.ALL_DONE)
latest_only >> task1 >> [task3, task4]
task2 >> [task3, task4]
diff --git a/airflow/example_dags/example_local_kubernetes_executor.py b/airflow/example_dags/example_local_kubernetes_executor.py
index e586cafda9416..db2d7acf0dbe5 100644
--- a/airflow/example_dags/example_local_kubernetes_executor.py
+++ b/airflow/example_dags/example_local_kubernetes_executor.py
@@ -18,6 +18,8 @@
"""
This is an example dag for using a Local Kubernetes Executor Configuration.
"""
+from __future__ import annotations
+
import logging
from datetime import datetime
@@ -28,8 +30,8 @@
log = logging.getLogger(__name__)
-worker_container_repository = conf.get('kubernetes', 'worker_container_repository')
-worker_container_tag = conf.get('kubernetes', 'worker_container_tag')
+worker_container_repository = conf.get("kubernetes_executor", "worker_container_repository")
+worker_container_tag = conf.get("kubernetes_executor", "worker_container_tag")
try:
from kubernetes.client import models as k8s
@@ -40,11 +42,11 @@
if k8s:
with DAG(
- dag_id='example_local_kubernetes_executor',
- schedule_interval=None,
+ dag_id="example_local_kubernetes_executor",
+ schedule=None,
start_date=datetime(2021, 1, 1),
catchup=False,
- tags=['example3'],
+ tags=["example3"],
) as dag:
# You can use annotations on your kubernetes pods!
start_task_executor_config = {
@@ -53,17 +55,17 @@
@task(
executor_config=start_task_executor_config,
- queue='kubernetes',
- task_id='task_with_kubernetes_executor',
+ queue="kubernetes",
+ task_id="task_with_kubernetes_executor",
)
def task_with_template():
print_stuff()
- @task(task_id='task_with_local_executor')
+ @task(task_id="task_with_local_executor")
def task_with_local(ds=None, **kwargs):
"""Print the Airflow context and ds variable from the context."""
print(kwargs)
print(ds)
- return 'Whatever you return gets printed in the logs'
+ return "Whatever you return gets printed in the logs"
task_with_local() >> task_with_template()
diff --git a/airflow/example_dags/example_nested_branch_dag.py b/airflow/example_dags/example_nested_branch_dag.py
index 14ac0a43ce245..7c46592455983 100644
--- a/airflow/example_dags/example_nested_branch_dag.py
+++ b/airflow/example_dags/example_nested_branch_dag.py
@@ -15,31 +15,38 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating a workflow with nested branching. The join tasks are created with
``none_failed_min_one_success`` trigger rule such that they are skipped whenever their corresponding
-``BranchPythonOperator`` are skipped.
+branching tasks are skipped.
"""
+from __future__ import annotations
+
import pendulum
+from airflow.decorators import task
from airflow.models import DAG
from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import BranchPythonOperator
from airflow.utils.trigger_rule import TriggerRule
with DAG(
dag_id="example_nested_branch_dag",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval="@daily",
+ schedule="@daily",
tags=["example"],
) as dag:
- branch_1 = BranchPythonOperator(task_id="branch_1", python_callable=lambda: "true_1")
+
+ @task.branch()
+ def branch(task_id_to_return: str) -> str:
+ return task_id_to_return
+
+ branch_1 = branch.override(task_id="branch_1")(task_id_to_return="true_1")
join_1 = EmptyOperator(task_id="join_1", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
true_1 = EmptyOperator(task_id="true_1")
false_1 = EmptyOperator(task_id="false_1")
- branch_2 = BranchPythonOperator(task_id="branch_2", python_callable=lambda: "true_2")
+
+ branch_2 = branch.override(task_id="branch_2")(task_id_to_return="true_2")
join_2 = EmptyOperator(task_id="join_2", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
true_2 = EmptyOperator(task_id="true_2")
false_2 = EmptyOperator(task_id="false_2")
diff --git a/airflow/example_dags/example_passing_params_via_test_command.py b/airflow/example_dags/example_passing_params_via_test_command.py
index 8057d5fd54a13..055c8639f90a4 100644
--- a/airflow/example_dags/example_passing_params_via_test_command.py
+++ b/airflow/example_dags/example_passing_params_via_test_command.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the params arguments in templated arguments."""
+from __future__ import annotations
import datetime
import os
@@ -59,11 +59,11 @@ def print_env_vars(test_mode=None):
with DAG(
"example_passing_params_via_test_command",
- schedule_interval='*/1 * * * *',
+ schedule="*/1 * * * *",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
dagrun_timeout=datetime.timedelta(minutes=4),
- tags=['example'],
+ tags=["example"],
) as dag:
run_this = my_py_command(params={"miff": "agg"})
@@ -75,7 +75,7 @@ def print_env_vars(test_mode=None):
)
also_run_this = BashOperator(
- task_id='also_run_this',
+ task_id="also_run_this",
bash_command=my_command,
params={"miff": "agg"},
env={"FOO": "{{ params.foo }}", "MIFF": "{{ params.miff }}"},
diff --git a/airflow/example_dags/example_python_operator.py b/airflow/example_dags/example_python_operator.py
index 0f9a7fc476acb..4f891abe60d02 100644
--- a/airflow/example_dags/example_python_operator.py
+++ b/airflow/example_dags/example_python_operator.py
@@ -15,13 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating the usage of the TaskFlow API to execute Python functions natively and within a
virtual environment.
"""
+from __future__ import annotations
+
import logging
import shutil
+import sys
+import tempfile
import time
from pprint import pprint
@@ -29,39 +32,58 @@
from airflow import DAG
from airflow.decorators import task
+from airflow.operators.python import ExternalPythonOperator, PythonVirtualenvOperator
log = logging.getLogger(__name__)
+PATH_TO_PYTHON_BINARY = sys.executable
+
+BASE_DIR = tempfile.gettempdir()
+
+
+def x():
+ pass
+
+
with DAG(
- dag_id='example_python_operator',
- schedule_interval=None,
+ dag_id="example_python_operator",
+ schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
+
# [START howto_operator_python]
@task(task_id="print_the_context")
def print_context(ds=None, **kwargs):
"""Print the Airflow context and ds variable from the context."""
pprint(kwargs)
print(ds)
- return 'Whatever you return gets printed in the logs'
+ return "Whatever you return gets printed in the logs"
run_this = print_context()
# [END howto_operator_python]
+ # [START howto_operator_python_render_sql]
+ @task(task_id="log_sql_query", templates_dict={"query": "sql/sample.sql"}, templates_exts=[".sql"])
+ def log_sql(**kwargs):
+ logging.info("Python task decorator query: %s", str(kwargs["templates_dict"]["query"]))
+
+ log_the_sql = log_sql()
+ # [END howto_operator_python_render_sql]
+
# [START howto_operator_python_kwargs]
# Generate 5 sleeping tasks, sleeping from 0.0 to 0.4 seconds respectively
for i in range(5):
- @task(task_id=f'sleep_for_{i}')
+ @task(task_id=f"sleep_for_{i}")
def my_sleeping_function(random_base):
"""This is a function that will run within the DAG execution"""
time.sleep(random_base)
sleeping_task = my_sleeping_function(random_base=float(i) / 10)
- run_this >> sleeping_task
+ run_this >> log_the_sql >> sleeping_task
# [END howto_operator_python_kwargs]
if not shutil.which("virtualenv"):
@@ -82,14 +104,56 @@ def callable_virtualenv():
from colorama import Back, Fore, Style
- print(Fore.RED + 'some red text')
- print(Back.GREEN + 'and with a green background')
- print(Style.DIM + 'and in dim text')
+ print(Fore.RED + "some red text")
+ print(Back.GREEN + "and with a green background")
+ print(Style.DIM + "and in dim text")
print(Style.RESET_ALL)
- for _ in range(10):
- print(Style.DIM + 'Please wait...', flush=True)
- sleep(10)
- print('Finished')
+ for _ in range(4):
+ print(Style.DIM + "Please wait...", flush=True)
+ sleep(1)
+ print("Finished")
virtualenv_task = callable_virtualenv()
# [END howto_operator_python_venv]
+
+ sleeping_task >> virtualenv_task
+
+ # [START howto_operator_external_python]
+ @task.external_python(task_id="external_python", python=PATH_TO_PYTHON_BINARY)
+ def callable_external_python():
+ """
+ Example function that will be performed in an external Python environment.
+
+ Importing at the function level ensures that it will not attempt to import the
+ library before it is installed.
+ """
+ import sys
+ from time import sleep
+
+ print(f"Running task via {sys.executable}")
+ print("Sleeping")
+ for _ in range(4):
+ print("Please wait...", flush=True)
+ sleep(1)
+ print("Finished")
+
+ external_python_task = callable_external_python()
+ # [END howto_operator_external_python]
+
+ # [START howto_operator_external_python_classic]
+ external_classic = ExternalPythonOperator(
+ task_id="external_python_classic",
+ python=PATH_TO_PYTHON_BINARY,
+ python_callable=x,
+ )
+ # [END howto_operator_external_python_classic]
+
+ # [START howto_operator_python_venv_classic]
+ virtual_classic = PythonVirtualenvOperator(
+ task_id="virtualenv_classic",
+ requirements="colorama==0.4.0",
+ python_callable=x,
+ )
+ # [END howto_operator_python_venv_classic]
+
+ run_this >> external_classic >> external_python_task >> virtual_classic
diff --git a/airflow/example_dags/example_sensor_decorator.py b/airflow/example_dags/example_sensor_decorator.py
new file mode 100644
index 0000000000000..2197a6c53af74
--- /dev/null
+++ b/airflow/example_dags/example_sensor_decorator.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the sensor decorator."""
+
+from __future__ import annotations
+
+# [START tutorial]
+# [START import_module]
+import pendulum
+
+from airflow.decorators import dag, task
+from airflow.sensors.base import PokeReturnValue
+
+# [END import_module]
+
+
+# [START instantiate_dag]
+@dag(
+ schedule_interval=None,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["example"],
+)
+def example_sensor_decorator():
+ # [END instantiate_dag]
+
+ # [START wait_function]
+ # Using a sensor operator to wait for the upstream data to be ready.
+ @task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
+ def wait_for_upstream() -> PokeReturnValue:
+ return PokeReturnValue(is_done=True, xcom_value="xcom_value")
+
+ # [END wait_function]
+
+ # [START dummy_function]
+ @task
+ def dummy_operator() -> None:
+ pass
+
+ # [END dummy_function]
+
+ # [START main_flow]
+ wait_for_upstream() >> dummy_operator()
+ # [END main_flow]
+
+
+# [START dag_invocation]
+tutorial_etl_dag = example_sensor_decorator()
+# [END dag_invocation]
+
+# [END tutorial]
diff --git a/airflow/example_dags/example_sensors.py b/airflow/example_dags/example_sensors.py
new file mode 100644
index 0000000000000..d9e3158f544bc
--- /dev/null
+++ b/airflow/example_dags/example_sensors.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+
+import pendulum
+from pytz import UTC
+
+from airflow.models import DAG
+from airflow.operators.bash import BashOperator
+from airflow.sensors.bash import BashSensor
+from airflow.sensors.filesystem import FileSensor
+from airflow.sensors.python import PythonSensor
+from airflow.sensors.time_delta import TimeDeltaSensor, TimeDeltaSensorAsync
+from airflow.sensors.time_sensor import TimeSensor, TimeSensorAsync
+from airflow.sensors.weekday import DayOfWeekSensor
+from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.weekday import WeekDay
+
+
+# [START example_callables]
+def success_callable():
+ return True
+
+
+def failure_callable():
+ return False
+
+
+# [END example_callables]
+
+
+with DAG(
+ dag_id="example_sensors",
+ schedule=None,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["example"],
+) as dag:
+ # [START example_time_delta_sensor]
+ t0 = TimeDeltaSensor(task_id="wait_some_seconds", delta=timedelta(seconds=2))
+ # [END example_time_delta_sensor]
+
+ # [START example_time_delta_sensor_async]
+ t0a = TimeDeltaSensorAsync(task_id="wait_some_seconds_async", delta=timedelta(seconds=2))
+ # [END example_time_delta_sensor_async]
+
+ # [START example_time_sensors]
+ t1 = TimeSensor(task_id="fire_immediately", target_time=datetime.now(tz=UTC).time())
+
+ t2 = TimeSensor(
+ task_id="timeout_after_second_date_in_the_future",
+ timeout=1,
+ soft_fail=True,
+ target_time=(datetime.now(tz=UTC) + timedelta(hours=1)).time(),
+ )
+ # [END example_time_sensors]
+
+ # [START example_time_sensors_async]
+ t1a = TimeSensorAsync(task_id="fire_immediately_async", target_time=datetime.now(tz=UTC).time())
+
+ t2a = TimeSensorAsync(
+ task_id="timeout_after_second_date_in_the_future_async",
+ timeout=1,
+ soft_fail=True,
+ target_time=(datetime.now(tz=UTC) + timedelta(hours=1)).time(),
+ )
+ # [END example_time_sensors_async]
+
+ # [START example_bash_sensors]
+ t3 = BashSensor(task_id="Sensor_succeeds", bash_command="exit 0")
+
+ t4 = BashSensor(task_id="Sensor_fails_after_3_seconds", timeout=3, soft_fail=True, bash_command="exit 1")
+ # [END example_bash_sensors]
+
+ t5 = BashOperator(task_id="remove_file", bash_command="rm -rf /tmp/temporary_file_for_testing")
+
+ # [START example_file_sensor]
+ t6 = FileSensor(task_id="wait_for_file", filepath="/tmp/temporary_file_for_testing")
+ # [END example_file_sensor]
+
+ t7 = BashOperator(
+ task_id="create_file_after_3_seconds", bash_command="sleep 3; touch /tmp/temporary_file_for_testing"
+ )
+
+ # [START example_python_sensors]
+ t8 = PythonSensor(task_id="success_sensor_python", python_callable=success_callable)
+
+ t9 = PythonSensor(
+ task_id="failure_timeout_sensor_python", timeout=3, soft_fail=True, python_callable=failure_callable
+ )
+ # [END example_python_sensors]
+
+ # [START example_day_of_week_sensor]
+ t10 = DayOfWeekSensor(
+ task_id="week_day_sensor_failing_on_timeout", timeout=3, soft_fail=True, week_day=WeekDay.MONDAY
+ )
+ # [END example_day_of_week_sensor]
+
+ tx = BashOperator(task_id="print_date_in_bash", bash_command="date")
+
+ tx.trigger_rule = TriggerRule.NONE_FAILED
+ [t0, t0a, t1, t1a, t2, t2a, t3, t4] >> tx
+ t5 >> t6 >> tx
+ t7 >> tx
+ [t8, t9] >> tx
+ t10 >> tx
diff --git a/airflow/example_dags/example_short_circuit_decorator.py b/airflow/example_dags/example_short_circuit_decorator.py
new file mode 100644
index 0000000000000..30f6cd0e012bb
--- /dev/null
+++ b/airflow/example_dags/example_short_circuit_decorator.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAG demonstrating the usage of the `@task.short_circuit()` TaskFlow decorator."""
+from __future__ import annotations
+
+import pendulum
+
+from airflow.decorators import dag, task
+from airflow.models.baseoperator import chain
+from airflow.operators.empty import EmptyOperator
+from airflow.utils.trigger_rule import TriggerRule
+
+
+@dag(start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=["example"])
+def example_short_circuit_decorator():
+ # [START howto_operator_short_circuit]
+ @task.short_circuit()
+ def check_condition(condition):
+ return condition
+
+ ds_true = [EmptyOperator(task_id="true_" + str(i)) for i in [1, 2]]
+ ds_false = [EmptyOperator(task_id="false_" + str(i)) for i in [1, 2]]
+
+ condition_is_true = check_condition.override(task_id="condition_is_true")(condition=True)
+ condition_is_false = check_condition.override(task_id="condition_is_false")(condition=False)
+
+ chain(condition_is_true, *ds_true)
+ chain(condition_is_false, *ds_false)
+ # [END howto_operator_short_circuit]
+
+ # [START howto_operator_short_circuit_trigger_rules]
+ [task_1, task_2, task_3, task_4, task_5, task_6] = [
+ EmptyOperator(task_id=f"task_{i}") for i in range(1, 7)
+ ]
+
+ task_7 = EmptyOperator(task_id="task_7", trigger_rule=TriggerRule.ALL_DONE)
+
+ short_circuit = check_condition.override(task_id="short_circuit", ignore_downstream_trigger_rules=False)(
+ condition=False
+ )
+
+ chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7)
+ # [END howto_operator_short_circuit_trigger_rules]
+
+
+example_dag = example_short_circuit_decorator()
diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py
index 2278de30e6294..77f976c502a16 100644
--- a/airflow/example_dags/example_short_circuit_operator.py
+++ b/airflow/example_dags/example_short_circuit_operator.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the ShortCircuitOperator."""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
@@ -26,30 +27,27 @@
from airflow.utils.trigger_rule import TriggerRule
with DAG(
- dag_id='example_short_circuit_operator',
+ dag_id="example_short_circuit_operator",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
- # [START howto_operator_short_circuit]
cond_true = ShortCircuitOperator(
- task_id='condition_is_True',
+ task_id="condition_is_True",
python_callable=lambda: True,
)
cond_false = ShortCircuitOperator(
- task_id='condition_is_False',
+ task_id="condition_is_False",
python_callable=lambda: False,
)
- ds_true = [EmptyOperator(task_id='true_' + str(i)) for i in [1, 2]]
- ds_false = [EmptyOperator(task_id='false_' + str(i)) for i in [1, 2]]
+ ds_true = [EmptyOperator(task_id="true_" + str(i)) for i in [1, 2]]
+ ds_false = [EmptyOperator(task_id="false_" + str(i)) for i in [1, 2]]
chain(cond_true, *ds_true)
chain(cond_false, *ds_false)
- # [END howto_operator_short_circuit]
- # [START howto_operator_short_circuit_trigger_rules]
[task_1, task_2, task_3, task_4, task_5, task_6] = [
EmptyOperator(task_id=f"task_{i}") for i in range(1, 7)
]
@@ -61,4 +59,3 @@
)
chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7)
- # [END howto_operator_short_circuit_trigger_rules]
diff --git a/airflow/example_dags/example_skip_dag.py b/airflow/example_dags/example_skip_dag.py
index 00d3a3d91e11b..c8b8958446686 100644
--- a/airflow/example_dags/example_skip_dag.py
+++ b/airflow/example_dags/example_skip_dag.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the EmptyOperator and a custom EmptySkipOperator which skips by default."""
+from __future__ import annotations
import pendulum
@@ -31,7 +31,7 @@
class EmptySkipOperator(EmptyOperator):
"""Empty operator which always skips the task."""
- ui_color = '#e8b7e4'
+ ui_color = "#e8b7e4"
def execute(self, context: Context):
raise AirflowSkipException
@@ -45,10 +45,10 @@ def create_test_pipeline(suffix, trigger_rule):
:param str trigger_rule: TriggerRule for the join task
:param DAG dag_: The DAG to run the operators on
"""
- skip_operator = EmptySkipOperator(task_id=f'skip_operator_{suffix}')
- always_true = EmptyOperator(task_id=f'always_true_{suffix}')
+ skip_operator = EmptySkipOperator(task_id=f"skip_operator_{suffix}")
+ always_true = EmptyOperator(task_id=f"always_true_{suffix}")
join = EmptyOperator(task_id=trigger_rule, trigger_rule=trigger_rule)
- final = EmptyOperator(task_id=f'final_{suffix}')
+ final = EmptyOperator(task_id=f"final_{suffix}")
skip_operator >> join
always_true >> join
@@ -56,10 +56,10 @@ def create_test_pipeline(suffix, trigger_rule):
with DAG(
- dag_id='example_skip_dag',
+ dag_id="example_skip_dag",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
- create_test_pipeline('1', TriggerRule.ALL_SUCCESS)
- create_test_pipeline('2', TriggerRule.ONE_SUCCESS)
+ create_test_pipeline("1", TriggerRule.ALL_SUCCESS)
+ create_test_pipeline("2", TriggerRule.ONE_SUCCESS)
diff --git a/airflow/example_dags/example_sla_dag.py b/airflow/example_dags/example_sla_dag.py
index 0db6bc1ba7fcc..e76f40f10ddad 100644
--- a/airflow/example_dags/example_sla_dag.py
+++ b/airflow/example_dags/example_sla_dag.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Example DAG demonstrating SLA use in Tasks"""
+from __future__ import annotations
import datetime
import time
@@ -22,8 +24,6 @@
from airflow.decorators import dag, task
-"""Example DAG demonstrating SLA use in Tasks"""
-
# [START howto_task_sla]
def sla_callback(dag, task_list, blocking_task_list, slas, blocking_tis):
@@ -40,11 +40,11 @@ def sla_callback(dag, task_list, blocking_task_list, slas, blocking_tis):
@dag(
- schedule_interval="*/2 * * * *",
+ schedule="*/2 * * * *",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
sla_miss_callback=sla_callback,
- default_args={'email': "email@example.com"},
+ default_args={"email": "email@example.com"},
)
def example_sla_dag():
@task(sla=datetime.timedelta(seconds=10))
@@ -60,6 +60,6 @@ def sleep_30():
sleep_20() >> sleep_30()
-dag = example_sla_dag()
+example_dag = example_sla_dag()
# [END howto_task_sla]
diff --git a/airflow/example_dags/example_subdag_operator.py b/airflow/example_dags/example_subdag_operator.py
index 79d369d638d6a..f7c3b098135bb 100644
--- a/airflow/example_dags/example_subdag_operator.py
+++ b/airflow/example_dags/example_subdag_operator.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the SubDagOperator."""
+from __future__ import annotations
# [START example_subdag_operator]
import datetime
@@ -26,36 +26,36 @@
from airflow.operators.empty import EmptyOperator
from airflow.operators.subdag import SubDagOperator
-DAG_NAME = 'example_subdag_operator'
+DAG_NAME = "example_subdag_operator"
with DAG(
dag_id=DAG_NAME,
default_args={"retries": 2},
start_date=datetime.datetime(2022, 1, 1),
- schedule_interval="@once",
- tags=['example'],
+ schedule="@once",
+ tags=["example"],
) as dag:
start = EmptyOperator(
- task_id='start',
+ task_id="start",
)
section_1 = SubDagOperator(
- task_id='section-1',
- subdag=subdag(DAG_NAME, 'section-1', dag.default_args),
+ task_id="section-1",
+ subdag=subdag(DAG_NAME, "section-1", dag.default_args),
)
some_other_task = EmptyOperator(
- task_id='some-other-task',
+ task_id="some-other-task",
)
section_2 = SubDagOperator(
- task_id='section-2',
- subdag=subdag(DAG_NAME, 'section-2', dag.default_args),
+ task_id="section-2",
+ subdag=subdag(DAG_NAME, "section-2", dag.default_args),
)
end = EmptyOperator(
- task_id='end',
+ task_id="end",
)
start >> section_1 >> some_other_task >> section_2 >> end
diff --git a/airflow/example_dags/example_task_group.py b/airflow/example_dags/example_task_group.py
index 7bc319f98b569..9d7a9f2e74d59 100644
--- a/airflow/example_dags/example_task_group.py
+++ b/airflow/example_dags/example_task_group.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the TaskGroup."""
+from __future__ import annotations
+
import pendulum
from airflow.models.dag import DAG
@@ -36,7 +37,7 @@
# [START howto_task_group_section_1]
with TaskGroup("section_1", tooltip="Tasks for section_1") as section_1:
task_1 = EmptyOperator(task_id="task_1")
- task_2 = BashOperator(task_id="task_2", bash_command='echo 1')
+ task_2 = BashOperator(task_id="task_2", bash_command="echo 1")
task_3 = EmptyOperator(task_id="task_3")
task_1 >> [task_2, task_3]
@@ -48,7 +49,7 @@
# [START howto_task_group_inner_section_2]
with TaskGroup("inner_section_2", tooltip="Tasks for inner_section2") as inner_section_2:
- task_2 = BashOperator(task_id="task_2", bash_command='echo 1')
+ task_2 = BashOperator(task_id="task_2", bash_command="echo 1")
task_3 = EmptyOperator(task_id="task_3")
task_4 = EmptyOperator(task_id="task_4")
@@ -57,7 +58,7 @@
# [END howto_task_group_section_2]
- end = EmptyOperator(task_id='end')
+ end = EmptyOperator(task_id="end")
start >> section_1 >> section_2 >> end
# [END howto_task_group]
diff --git a/airflow/example_dags/example_task_group_decorator.py b/airflow/example_dags/example_task_group_decorator.py
index 637c721886bea..db8d7b302529c 100644
--- a/airflow/example_dags/example_task_group_decorator.py
+++ b/airflow/example_dags/example_task_group_decorator.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the @taskgroup decorator."""
+from __future__ import annotations
import pendulum
@@ -29,31 +29,31 @@
@task
def task_start():
"""Empty Task which is First Task of Dag"""
- return '[Task_start]'
+ return "[Task_start]"
@task
def task_1(value: int) -> str:
"""Empty Task1"""
- return f'[ Task1 {value} ]'
+ return f"[ Task1 {value} ]"
@task
def task_2(value: str) -> str:
"""Empty Task2"""
- return f'[ Task2 {value} ]'
+ return f"[ Task2 {value} ]"
@task
def task_3(value: str) -> None:
"""Empty Task3"""
- print(f'[ Task3 {value} ]')
+ print(f"[ Task3 {value} ]")
@task
def task_end() -> None:
"""Empty Task which is Last Task of Dag"""
- print('[ Task_End ]')
+ print("[ Task_End ]")
# Creating TaskGroups
diff --git a/airflow/example_dags/example_time_delta_sensor_async.py b/airflow/example_dags/example_time_delta_sensor_async.py
index a2aa3cb66edee..d1562c5751d7d 100644
--- a/airflow/example_dags/example_time_delta_sensor_async.py
+++ b/airflow/example_dags/example_time_delta_sensor_async.py
@@ -15,11 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example DAG demonstrating ``TimeDeltaSensorAsync``, a drop in replacement for ``TimeDeltaSensor`` that
defers and doesn't occupy a worker slot while it waits
"""
+from __future__ import annotations
import datetime
@@ -31,7 +31,7 @@
with DAG(
dag_id="example_time_delta_sensor_async",
- schedule_interval=None,
+ schedule=None,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
tags=["example"],
diff --git a/airflow/example_dags/example_trigger_controller_dag.py b/airflow/example_dags/example_trigger_controller_dag.py
index a017c9a5b4176..c07fc3b190b0c 100644
--- a/airflow/example_dags/example_trigger_controller_dag.py
+++ b/airflow/example_dags/example_trigger_controller_dag.py
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example usage of the TriggerDagRunOperator. This example holds 2 DAGs:
1. 1st DAG (example_trigger_controller_dag) holds a TriggerDagRunOperator, which will trigger the 2nd DAG
2. 2nd DAG (example_trigger_target_dag) which will be triggered by the TriggerDagRunOperator in the 1st DAG
"""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
@@ -30,8 +31,8 @@
dag_id="example_trigger_controller_dag",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval="@once",
- tags=['example'],
+ schedule="@once",
+ tags=["example"],
) as dag:
trigger = TriggerDagRunOperator(
task_id="test_trigger_dagrun",
diff --git a/airflow/example_dags/example_trigger_target_dag.py b/airflow/example_dags/example_trigger_target_dag.py
index 20932338c8dd8..ff619376da0dc 100644
--- a/airflow/example_dags/example_trigger_target_dag.py
+++ b/airflow/example_dags/example_trigger_target_dag.py
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Example usage of the TriggerDagRunOperator. This example holds 2 DAGs:
1. 1st DAG (example_trigger_controller_dag) holds a TriggerDagRunOperator, which will trigger the 2nd DAG
2. 2nd DAG (example_trigger_target_dag) which will be triggered by the TriggerDagRunOperator in the 1st DAG
"""
+from __future__ import annotations
+
import pendulum
from airflow import DAG
@@ -42,13 +43,13 @@ def run_this_func(dag_run=None):
dag_id="example_trigger_target_dag",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval=None,
- tags=['example'],
+ schedule=None,
+ tags=["example"],
) as dag:
run_this = run_this_func()
bash_task = BashOperator(
task_id="bash_task",
bash_command='echo "Here is the message: $message"',
- env={'message': '{{ dag_run.conf.get("message") }}'},
+ env={"message": '{{ dag_run.conf.get("message") }}'},
)
diff --git a/airflow/example_dags/example_xcom.py b/airflow/example_dags/example_xcom.py
index b55d4e5d667cd..8455ad575cc82 100644
--- a/airflow/example_dags/example_xcom.py
+++ b/airflow/example_dags/example_xcom.py
@@ -15,22 +15,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of XComs."""
+from __future__ import annotations
+
import pendulum
-from airflow import DAG
+from airflow import DAG, XComArg
from airflow.decorators import task
from airflow.operators.bash import BashOperator
value_1 = [1, 2, 3]
-value_2 = {'a': 'b'}
+value_2 = {"a": "b"}
@task
def push(ti=None):
"""Pushes an XCom without a specific target"""
- ti.xcom_push(key='value from pusher 1', value=value_1)
+ ti.xcom_push(key="value from pusher 1", value=value_1)
@task
@@ -41,7 +42,7 @@ def push_by_returning():
def _compare_values(pulled_value, check_value):
if pulled_value != check_value:
- raise ValueError(f'The two values differ {pulled_value} and {check_value}')
+ raise ValueError(f"The two values differ {pulled_value} and {check_value}")
@task
@@ -55,21 +56,21 @@ def puller(pulled_value_2, ti=None):
@task
def pull_value_from_bash_push(ti=None):
- bash_pushed_via_return_value = ti.xcom_pull(key="return_value", task_ids='bash_push')
- bash_manually_pushed_value = ti.xcom_pull(key="manually_pushed_value", task_ids='bash_push')
+ bash_pushed_via_return_value = ti.xcom_pull(key="return_value", task_ids="bash_push")
+ bash_manually_pushed_value = ti.xcom_pull(key="manually_pushed_value", task_ids="bash_push")
print(f"The xcom value pushed by task push via return value is {bash_pushed_via_return_value}")
print(f"The xcom value pushed by task push manually is {bash_manually_pushed_value}")
with DAG(
- 'example_xcom',
- schedule_interval="@once",
+ "example_xcom",
+ schedule="@once",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
bash_push = BashOperator(
- task_id='bash_push',
+ task_id="bash_push",
bash_command='echo "bash_push demo" && '
'echo "Manually set xcom value '
'{{ ti.xcom_push(key="manually_pushed_value", value="manually_pushed_value") }}" && '
@@ -77,10 +78,10 @@ def pull_value_from_bash_push(ti=None):
)
bash_pull = BashOperator(
- task_id='bash_pull',
+ task_id="bash_pull",
bash_command='echo "bash pull demo" && '
- f'echo "The xcom pushed manually is {bash_push.output["manually_pushed_value"]}" && '
- f'echo "The returned_value xcom is {bash_push.output}" && '
+ f'echo "The xcom pushed manually is {XComArg(bash_push, key="manually_pushed_value")}" && '
+ f'echo "The returned_value xcom is {XComArg(bash_push)}" && '
'echo "finished"',
do_xcom_push=False,
)
@@ -90,6 +91,3 @@ def pull_value_from_bash_push(ti=None):
[bash_pull, python_pull_from_bash] << bash_push
puller(push_by_returning()) << push()
-
- # Task dependencies created via `XComArgs`:
- # pull << push2
diff --git a/airflow/example_dags/example_xcomargs.py b/airflow/example_dags/example_xcomargs.py
index 8312aca8c9d25..9d36cb535a8c1 100644
--- a/airflow/example_dags/example_xcomargs.py
+++ b/airflow/example_dags/example_xcomargs.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Example DAG demonstrating the usage of the XComArgs."""
+from __future__ import annotations
+
import logging
import pendulum
@@ -41,11 +42,11 @@ def print_value(value, ts=None):
with DAG(
- dag_id='example_xcom_args',
+ dag_id="example_xcom_args",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval=None,
- tags=['example'],
+ schedule=None,
+ tags=["example"],
) as dag:
print_value(generate_value())
@@ -53,8 +54,8 @@ def print_value(value, ts=None):
"example_xcom_args_with_operators",
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval=None,
- tags=['example'],
+ schedule=None,
+ tags=["example"],
) as dag2:
bash_op1 = BashOperator(task_id="c", bash_command="echo c")
bash_op2 = BashOperator(task_id="d", bash_command="echo c")
diff --git a/airflow/example_dags/libs/helper.py b/airflow/example_dags/libs/helper.py
index a3d3a720a0255..e6c2e3c4582fc 100644
--- a/airflow/example_dags/libs/helper.py
+++ b/airflow/example_dags/libs/helper.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
def print_stuff():
diff --git a/airflow/example_dags/plugins/workday.py b/airflow/example_dags/plugins/workday.py
index 77111a79396de..db68c29541a8f 100644
--- a/airflow/example_dags/plugins/workday.py
+++ b/airflow/example_dags/plugins/workday.py
@@ -15,12 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Plugin to demonstrate timetable registration and accommodate example DAGs."""
+from __future__ import annotations
# [START howto_timetable]
from datetime import timedelta
-from typing import Optional
from pendulum import UTC, Date, DateTime, Time
@@ -47,9 +46,9 @@ def infer_manual_data_interval(self, run_after: DateTime) -> DataInterval:
def next_dagrun_info(
self,
*,
- last_automated_data_interval: Optional[DataInterval],
+ last_automated_data_interval: DataInterval | None,
restriction: TimeRestriction,
- ) -> Optional[DagRunInfo]:
+ ) -> DagRunInfo | None:
if last_automated_data_interval is not None: # There was a previous run on the regular schedule.
last_start = last_automated_data_interval.start
last_start_weekday = last_start.weekday()
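The annotation change in this file (``Optional[DataInterval]`` to ``DataInterval | None``) relies on the ``from __future__ import annotations`` line added at the top. With that import, annotations are stored as strings rather than evaluated at definition time, so the ``X | None`` spelling is accepted even on interpreters older than Python 3.10. A minimal standalone sketch follows, assuming a stand-in ``DataInterval`` class and a hypothetical ``describe_next_run`` function, neither of which is part of Airflow:

```python
from __future__ import annotations


class DataInterval:
    """Hypothetical stand-in for Airflow's DataInterval; illustration only."""

    def __init__(self, start: int, end: int) -> None:
        self.start = start
        self.end = end


def describe_next_run(last_automated_data_interval: DataInterval | None) -> str | None:
    """Describe the previous interval, or return None if there was no previous run."""
    if last_automated_data_interval is None:
        return None
    return f"{last_automated_data_interval.start}-{last_automated_data_interval.end}"


# Under the future import, annotations are kept as unevaluated strings.
annotations = describe_next_run.__annotations__
```

Because the annotations are never evaluated, ``typing.Optional`` no longer has to be imported just to express "or ``None``", which is why the ``from typing import Optional`` line is dropped in the same hunk.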
diff --git a/airflow/example_dags/sql/sample.sql b/airflow/example_dags/sql/sample.sql
new file mode 100644
index 0000000000000..23af6ab4b9bb3
--- /dev/null
+++ b/airflow/example_dags/sql/sample.sql
@@ -0,0 +1,24 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+*/
+
+CREATE TABLE Orders (
+ order_id INT PRIMARY KEY,
+ name TEXT,
+ description TEXT
+)
diff --git a/airflow/example_dags/subdags/subdag.py b/airflow/example_dags/subdags/subdag.py
index 2fcab731092fb..0bb9e86948eb1 100644
--- a/airflow/example_dags/subdags/subdag.py
+++ b/airflow/example_dags/subdags/subdag.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Helper function to generate a DAG and operators given some arguments."""
+from __future__ import annotations
# [START subdag]
import pendulum
@@ -25,7 +25,7 @@
from airflow.operators.empty import EmptyOperator
-def subdag(parent_dag_name, child_dag_name, args):
+def subdag(parent_dag_name, child_dag_name, args) -> DAG:
"""
Generate a DAG to be used as a subdag.
@@ -33,19 +33,18 @@ def subdag(parent_dag_name, child_dag_name, args):
:param str child_dag_name: Id of the child DAG
:param dict args: Default arguments to provide to the subdag
:return: DAG to use as a subdag
- :rtype: airflow.models.DAG
"""
dag_subdag = DAG(
- dag_id=f'{parent_dag_name}.{child_dag_name}',
+ dag_id=f"{parent_dag_name}.{child_dag_name}",
default_args=args,
start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
catchup=False,
- schedule_interval="@daily",
+ schedule="@daily",
)
for i in range(5):
EmptyOperator(
- task_id=f'{child_dag_name}-task-{i + 1}',
+ task_id=f"{child_dag_name}-task-{i + 1}",
default_args=args,
dag=dag_subdag,
)
diff --git a/airflow/example_dags/tutorial.py b/airflow/example_dags/tutorial.py
index ff2bd2fe95cf7..dc6399625c98a 100644
--- a/airflow/example_dags/tutorial.py
+++ b/airflow/example_dags/tutorial.py
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
### Tutorial Documentation
Documentation that goes along with the Airflow tutorial located
[here](https://airflow.apache.org/tutorial.html)
"""
+from __future__ import annotations
+
# [START tutorial]
# [START import_module]
from datetime import datetime, timedelta
@@ -37,17 +38,17 @@
# [START instantiate_dag]
with DAG(
- 'tutorial',
+ "tutorial",
# [START default_args]
# These args will get passed on to each operator
# You can override them on a per-task basis during operator initialization
default_args={
- 'depends_on_past': False,
- 'email': ['airflow@example.com'],
- 'email_on_failure': False,
- 'email_on_retry': False,
- 'retries': 1,
- 'retry_delay': timedelta(minutes=5),
+ "depends_on_past": False,
+ "email": ["airflow@example.com"],
+ "email_on_failure": False,
+ "email_on_retry": False,
+ "retries": 1,
+ "retry_delay": timedelta(minutes=5),
# 'queue': 'bash_queue',
# 'pool': 'backfill',
# 'priority_weight': 10,
@@ -62,25 +63,25 @@
# 'trigger_rule': 'all_success'
},
# [END default_args]
- description='A simple tutorial DAG',
- schedule_interval=timedelta(days=1),
+ description="A simple tutorial DAG",
+ schedule=timedelta(days=1),
start_date=datetime(2021, 1, 1),
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
# [END instantiate_dag]
# t1, t2 and t3 are examples of tasks created by instantiating operators
# [START basic_task]
t1 = BashOperator(
- task_id='print_date',
- bash_command='date',
+ task_id="print_date",
+ bash_command="date",
)
t2 = BashOperator(
- task_id='sleep',
+ task_id="sleep",
depends_on_past=False,
- bash_command='sleep 5',
+ bash_command="sleep 5",
retries=3,
)
# [END basic_task]
@@ -93,11 +94,11 @@
`doc` (plain text), `doc_rst`, `doc_json`, `doc_yaml` which gets
rendered in the UI's Task Instance Details page.
![img](http://montcs.bloomu.edu/~bobmon/Semesters/2012-01/491/import%20soul.png)
-
+ **Image Credit:** Randall Munroe, [XKCD](https://xkcd.com/license.html)
"""
)
- dag.doc_md = __doc__ # providing that you have a docstring at the beginning of the DAG
+ dag.doc_md = __doc__ # providing that you have a docstring at the beginning of the DAG; OR
dag.doc_md = """
This is a documentation placed anywhere
""" # otherwise, type it like this
@@ -114,7 +115,7 @@
)
t3 = BashOperator(
- task_id='templated',
+ task_id="templated",
depends_on_past=False,
bash_command=templated_command,
)
diff --git a/airflow/example_dags/tutorial_dag.py b/airflow/example_dags/tutorial_dag.py
new file mode 100644
index 0000000000000..07b193865de4b
--- /dev/null
+++ b/airflow/example_dags/tutorial_dag.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+### DAG Tutorial Documentation
+This DAG demonstrates an Extract -> Transform -> Load pipeline.
+"""
+from __future__ import annotations
+
+# [START tutorial]
+# [START import_module]
+import json
+from textwrap import dedent
+
+import pendulum
+
+# The DAG object; we'll need this to instantiate a DAG
+from airflow import DAG
+
+# Operators; we need this to operate!
+from airflow.operators.python import PythonOperator
+
+# [END import_module]
+
+# [START instantiate_dag]
+with DAG(
+ "tutorial_dag",
+ # [START default_args]
+ # These args will get passed on to each operator
+ # You can override them on a per-task basis during operator initialization
+ default_args={"retries": 2},
+ # [END default_args]
+ description="DAG tutorial",
+ schedule=None,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["example"],
+) as dag:
+ # [END instantiate_dag]
+ # [START documentation]
+ dag.doc_md = __doc__
+ # [END documentation]
+
+ # [START extract_function]
+ def extract(**kwargs):
+ ti = kwargs["ti"]
+ data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
+ ti.xcom_push("order_data", data_string)
+
+ # [END extract_function]
+
+ # [START transform_function]
+ def transform(**kwargs):
+ ti = kwargs["ti"]
+ extract_data_string = ti.xcom_pull(task_ids="extract", key="order_data")
+ order_data = json.loads(extract_data_string)
+
+ total_order_value = 0
+ for value in order_data.values():
+ total_order_value += value
+
+ total_value = {"total_order_value": total_order_value}
+ total_value_json_string = json.dumps(total_value)
+ ti.xcom_push("total_order_value", total_value_json_string)
+
+ # [END transform_function]
+
+ # [START load_function]
+ def load(**kwargs):
+ ti = kwargs["ti"]
+ total_value_string = ti.xcom_pull(task_ids="transform", key="total_order_value")
+ total_order_value = json.loads(total_value_string)
+
+ print(total_order_value)
+
+ # [END load_function]
+
+ # [START main_flow]
+ extract_task = PythonOperator(
+ task_id="extract",
+ python_callable=extract,
+ )
+ extract_task.doc_md = dedent(
+ """\
+ #### Extract task
+ A simple Extract task to get data ready for the rest of the data pipeline.
+ In this case, getting data is simulated by reading from a hardcoded JSON string.
+ This data is then put into xcom, so that it can be processed by the next task.
+ """
+ )
+
+ transform_task = PythonOperator(
+ task_id="transform",
+ python_callable=transform,
+ )
+ transform_task.doc_md = dedent(
+ """\
+ #### Transform task
+ A simple Transform task which takes in the collection of order data from xcom
+ and computes the total order value.
+ This computed value is then put into xcom, so that it can be processed by the next task.
+ """
+ )
+
+ load_task = PythonOperator(
+ task_id="load",
+ python_callable=load,
+ )
+ load_task.doc_md = dedent(
+ """\
+ #### Load task
+ A simple Load task which takes in the result of the Transform task by reading it
+ from xcom and, instead of saving it for end user review, just prints it out.
+ """
+ )
+
+ extract_task >> transform_task >> load_task
+
+# [END main_flow]
+
+# [END tutorial]
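The three tasks in ``tutorial_dag.py`` hand data to one another as JSON strings through XCom. The same data flow can be sketched without Airflow, assuming a plain dictionary plays the role of the XCom store. The name ``fake_xcom`` and its ``(task_id, key)`` keying are hypothetical stand-ins for illustration, not an Airflow API:

```python
from __future__ import annotations

import json

# Hypothetical stand-in for the XCom store, keyed by (task_id, key); not an Airflow API.
fake_xcom: dict[tuple[str, str], str] = {}


def extract() -> None:
    # Simulate fetching data by using a hardcoded JSON string, as the tutorial does.
    data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
    fake_xcom[("extract", "order_data")] = data_string


def transform() -> None:
    # Read the upstream value, total the orders, and store the result as JSON again.
    order_data = json.loads(fake_xcom[("extract", "order_data")])
    total_order_value = sum(order_data.values())
    fake_xcom[("transform", "total_order_value")] = json.dumps(
        {"total_order_value": total_order_value}
    )


def load() -> dict:
    # Read the transformed value back; the tutorial's load task simply prints it.
    return json.loads(fake_xcom[("transform", "total_order_value")])


# Run in the same order as extract_task >> transform_task >> load_task.
extract()
transform()
result = load()
print(result)
```

Each step reads only what an upstream step stored under an explicit key, which mirrors the ``ti.xcom_push`` and ``ti.xcom_pull`` pairs in the DAG.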
diff --git a/airflow/example_dags/tutorial_etl_dag.py b/airflow/example_dags/tutorial_etl_dag.py
deleted file mode 100644
index d039a73488c18..0000000000000
--- a/airflow/example_dags/tutorial_etl_dag.py
+++ /dev/null
@@ -1,135 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-"""
-### ETL DAG Tutorial Documentation
-This ETL DAG is demonstrating an Extract -> Transform -> Load pipeline
-"""
-# [START tutorial]
-# [START import_module]
-import json
-from textwrap import dedent
-
-import pendulum
-
-# The DAG object; we'll need this to instantiate a DAG
-from airflow import DAG
-
-# Operators; we need this to operate!
-from airflow.operators.python import PythonOperator
-
-# [END import_module]
-
-# [START instantiate_dag]
-with DAG(
- 'tutorial_etl_dag',
- # [START default_args]
- # These args will get passed on to each operator
- # You can override them on a per-task basis during operator initialization
- default_args={'retries': 2},
- # [END default_args]
- description='ETL DAG tutorial',
- schedule_interval=None,
- start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
- catchup=False,
- tags=['example'],
-) as dag:
- # [END instantiate_dag]
- # [START documentation]
- dag.doc_md = __doc__
- # [END documentation]
-
- # [START extract_function]
- def extract(**kwargs):
- ti = kwargs['ti']
- data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
- ti.xcom_push('order_data', data_string)
-
- # [END extract_function]
-
- # [START transform_function]
- def transform(**kwargs):
- ti = kwargs['ti']
- extract_data_string = ti.xcom_pull(task_ids='extract', key='order_data')
- order_data = json.loads(extract_data_string)
-
- total_order_value = 0
- for value in order_data.values():
- total_order_value += value
-
- total_value = {"total_order_value": total_order_value}
- total_value_json_string = json.dumps(total_value)
- ti.xcom_push('total_order_value', total_value_json_string)
-
- # [END transform_function]
-
- # [START load_function]
- def load(**kwargs):
- ti = kwargs['ti']
- total_value_string = ti.xcom_pull(task_ids='transform', key='total_order_value')
- total_order_value = json.loads(total_value_string)
-
- print(total_order_value)
-
- # [END load_function]
-
- # [START main_flow]
- extract_task = PythonOperator(
- task_id='extract',
- python_callable=extract,
- )
- extract_task.doc_md = dedent(
- """\
- #### Extract task
- A simple Extract task to get data ready for the rest of the data pipeline.
- In this case, getting data is simulated by reading from a hardcoded JSON string.
- This data is then put into xcom, so that it can be processed by the next task.
- """
- )
-
- transform_task = PythonOperator(
- task_id='transform',
- python_callable=transform,
- )
- transform_task.doc_md = dedent(
- """\
- #### Transform task
- A simple Transform task which takes in the collection of order data from xcom
- and computes the total order value.
- This computed value is then put into xcom, so that it can be processed by the next task.
- """
- )
-
- load_task = PythonOperator(
- task_id='load',
- python_callable=load,
- )
- load_task.doc_md = dedent(
- """\
- #### Load task
- A simple Load task which takes in the result of the Transform task, by reading it
- from xcom and instead of saving it to end user review, just prints it out.
- """
- )
-
- extract_task >> transform_task >> load_task
-
-# [END main_flow]
-
-# [END tutorial]
diff --git a/airflow/example_dags/tutorial_taskflow_api.py b/airflow/example_dags/tutorial_taskflow_api.py
new file mode 100644
index 0000000000000..f41f729af8870
--- /dev/null
+++ b/airflow/example_dags/tutorial_taskflow_api.py
@@ -0,0 +1,106 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+# [START tutorial]
+# [START import_module]
+import json
+
+import pendulum
+
+from airflow.decorators import dag, task
+
+# [END import_module]
+
+
+# [START instantiate_dag]
+@dag(
+ schedule=None,
+ start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+ catchup=False,
+ tags=["example"],
+)
+def tutorial_taskflow_api():
+ """
+ ### TaskFlow API Tutorial Documentation
+ This is a simple data pipeline example which demonstrates the use of
+ the TaskFlow API using three simple tasks for Extract, Transform, and Load.
+ Documentation that goes along with the Airflow TaskFlow API tutorial is
+ located
+ [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
+ """
+ # [END instantiate_dag]
+
+ # [START extract]
+ @task()
+ def extract():
+ """
+ #### Extract task
+ A simple Extract task to get data ready for the rest of the data
+ pipeline. In this case, getting data is simulated by reading from a
+ hardcoded JSON string.
+ """
+ data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
+
+ order_data_dict = json.loads(data_string)
+ return order_data_dict
+
+ # [END extract]
+
+ # [START transform]
+ @task(multiple_outputs=True)
+ def transform(order_data_dict: dict):
+ """
+ #### Transform task
+ A simple Transform task which takes in the collection of order data and
+ computes the total order value.
+ """
+ total_order_value = 0
+
+ for value in order_data_dict.values():
+ total_order_value += value
+
+ return {"total_order_value": total_order_value}
+
+ # [END transform]
+
+ # [START load]
+ @task()
+ def load(total_order_value: float):
+ """
+ #### Load task
+ A simple Load task which takes in the result of the Transform task and
+ instead of saving it to end user review, just prints it out.
+ """
+
+ print(f"Total order value is: {total_order_value:.2f}")
+
+ # [END load]
+
+ # [START main_flow]
+ order_data = extract()
+ order_summary = transform(order_data)
+ load(order_summary["total_order_value"])
+ # [END main_flow]
+
+
+# [START dag_invocation]
+tutorial_taskflow_api()
+# [END dag_invocation]
+
+# [END tutorial]
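
The data flow in the renamed `tutorial_taskflow_api` DAG can be followed without a running scheduler. The sketch below is a plain-Python approximation, not the DAG itself: the `@dag` and `@task` decorators are dropped, so values pass as ordinary return values rather than through XCom, and `load` additionally returns its message (an assumption made here so the result can be checked). In the real DAG, `multiple_outputs=True` is what lets `order_summary["total_order_value"]` be referenced as its own XCom entry.

```python
import json


def extract() -> dict:
    # Simulate reading source data from a hardcoded JSON string.
    data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
    return json.loads(data_string)


def transform(order_data_dict: dict) -> dict:
    # Sum all order values; the dict shape mirrors the multiple_outputs return.
    total_order_value = 0
    for value in order_data_dict.values():
        total_order_value += value
    return {"total_order_value": total_order_value}


def load(total_order_value: float) -> str:
    # Print the result instead of persisting it; returning it is for testing only.
    message = f"Total order value is: {total_order_value:.2f}"
    print(message)
    return message


order_data = extract()
order_summary = transform(order_data)
result = load(order_summary["total_order_value"])
```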
diff --git a/airflow/example_dags/tutorial_taskflow_api_etl.py b/airflow/example_dags/tutorial_taskflow_api_etl.py
deleted file mode 100644
index f6af78f0a5a2c..0000000000000
--- a/airflow/example_dags/tutorial_taskflow_api_etl.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-# [START tutorial]
-# [START import_module]
-import json
-
-import pendulum
-
-from airflow.decorators import dag, task
-
-# [END import_module]
-
-
-# [START instantiate_dag]
-@dag(
- schedule_interval=None,
- start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
- catchup=False,
- tags=['example'],
-)
-def tutorial_taskflow_api_etl():
- """
- ### TaskFlow API Tutorial Documentation
- This is a simple ETL data pipeline example which demonstrates the use of
- the TaskFlow API using three simple tasks for Extract, Transform, and Load.
- Documentation that goes along with the Airflow TaskFlow API tutorial is
- located
- [here](https://airflow.apache.org/docs/apache-airflow/stable/tutorial_taskflow_api.html)
- """
- # [END instantiate_dag]
-
- # [START extract]
- @task()
- def extract():
- """
- #### Extract task
- A simple Extract task to get data ready for the rest of the data
- pipeline. In this case, getting data is simulated by reading from a
- hardcoded JSON string.
- """
- data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
-
- order_data_dict = json.loads(data_string)
- return order_data_dict
-
- # [END extract]
-
- # [START transform]
- @task(multiple_outputs=True)
- def transform(order_data_dict: dict):
- """
- #### Transform task
- A simple Transform task which takes in the collection of order data and
- computes the total order value.
- """
- total_order_value = 0
-
- for value in order_data_dict.values():
- total_order_value += value
-
- return {"total_order_value": total_order_value}
-
- # [END transform]
-
- # [START load]
- @task()
- def load(total_order_value: float):
- """
- #### Load task
- A simple Load task which takes in the result of the Transform task and
- instead of saving it to end user review, just prints it out.
- """
-
- print(f"Total order value is: {total_order_value:.2f}")
-
- # [END load]
-
- # [START main_flow]
- order_data = extract()
- order_summary = transform(order_data)
- load(order_summary["total_order_value"])
- # [END main_flow]
-
-
-# [START dag_invocation]
-tutorial_etl_dag = tutorial_taskflow_api_etl()
-# [END dag_invocation]
-
-# [END tutorial]
diff --git a/airflow/example_dags/tutorial_taskflow_api_etl_virtualenv.py b/airflow/example_dags/tutorial_taskflow_api_etl_virtualenv.py
deleted file mode 100644
index ac280956b7f45..0000000000000
--- a/airflow/example_dags/tutorial_taskflow_api_etl_virtualenv.py
+++ /dev/null
@@ -1,89 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import logging
-import shutil
-from datetime import datetime
-
-from airflow.decorators import dag, task
-
-log = logging.getLogger(__name__)
-
-if not shutil.which("virtualenv"):
- log.warning(
- "The tutorial_taskflow_api_etl_virtualenv example DAG requires virtualenv, please install it."
- )
-else:
-
- @dag(schedule_interval=None, start_date=datetime(2021, 1, 1), catchup=False, tags=['example'])
- def tutorial_taskflow_api_etl_virtualenv():
- """
- ### TaskFlow API example using virtualenv
- This is a simple ETL data pipeline example which demonstrates the use of
- the TaskFlow API using three simple tasks for Extract, Transform, and Load.
- """
-
- @task.virtualenv(
- use_dill=True,
- system_site_packages=False,
- requirements=['funcsigs'],
- )
- def extract():
- """
- #### Extract task
- A simple Extract task to get data ready for the rest of the data
- pipeline. In this case, getting data is simulated by reading from a
- hardcoded JSON string.
- """
- import json
-
- data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
-
- order_data_dict = json.loads(data_string)
- return order_data_dict
-
- @task(multiple_outputs=True)
- def transform(order_data_dict: dict):
- """
- #### Transform task
- A simple Transform task which takes in the collection of order data and
- computes the total order value.
- """
- total_order_value = 0
-
- for value in order_data_dict.values():
- total_order_value += value
-
- return {"total_order_value": total_order_value}
-
- @task()
- def load(total_order_value: float):
- """
- #### Load task
- A simple Load task which takes in the result of the Transform task and
- instead of saving it to end user review, just prints it out.
- """
-
- print(f"Total order value is: {total_order_value:.2f}")
-
- order_data = extract()
- order_summary = transform(order_data)
- load(order_summary["total_order_value"])
-
- tutorial_etl_dag = tutorial_taskflow_api_etl_virtualenv()
diff --git a/airflow/example_dags/tutorial_taskflow_api_virtualenv.py b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py
new file mode 100644
index 0000000000000..a78116d2f7493
--- /dev/null
+++ b/airflow/example_dags/tutorial_taskflow_api_virtualenv.py
@@ -0,0 +1,87 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+import shutil
+from datetime import datetime
+
+from airflow.decorators import dag, task
+
+log = logging.getLogger(__name__)
+
+if not shutil.which("virtualenv"):
+ log.warning("The tutorial_taskflow_api_virtualenv example DAG requires virtualenv, please install it.")
+else:
+
+ @dag(schedule=None, start_date=datetime(2021, 1, 1), catchup=False, tags=["example"])
+ def tutorial_taskflow_api_virtualenv():
+ """
+ ### TaskFlow API example using virtualenv
+ This is a simple data pipeline example which demonstrates the use of
+ the TaskFlow API using three simple tasks for Extract, Transform, and Load.
+ """
+
+ @task.virtualenv(
+ use_dill=True,
+ system_site_packages=False,
+ requirements=["funcsigs"],
+ )
+ def extract():
+ """
+ #### Extract task
+ A simple Extract task to get data ready for the rest of the data
+ pipeline. In this case, getting data is simulated by reading from a
+ hardcoded JSON string.
+ """
+ import json
+
+ data_string = '{"1001": 301.27, "1002": 433.21, "1003": 502.22}'
+
+ order_data_dict = json.loads(data_string)
+ return order_data_dict
+
+ @task(multiple_outputs=True)
+ def transform(order_data_dict: dict):
+ """
+ #### Transform task
+ A simple Transform task which takes in the collection of order data and
+ computes the total order value.
+ """
+ total_order_value = 0
+
+ for value in order_data_dict.values():
+ total_order_value += value
+
+ return {"total_order_value": total_order_value}
+
+ @task()
+ def load(total_order_value: float):
+ """
+ #### Load task
+ A simple Load task which takes in the result of the Transform task and
+ instead of saving it to end user review, just prints it out.
+ """
+
+ print(f"Total order value is: {total_order_value:.2f}")
+
+ order_data = extract()
+ order_summary = transform(order_data)
+ load(order_summary["total_order_value"])
+
+ tutorial_dag = tutorial_taskflow_api_virtualenv()
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index 2f1e53e182a10..a60e3f90ffc73 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -15,14 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state
-"""Exceptions used by Airflow"""
+"""Exceptions used by Airflow."""
+from __future__ import annotations
+
import datetime
import warnings
from http import HTTPStatus
-from typing import Any, Dict, List, NamedTuple, Optional, Sized
+from typing import TYPE_CHECKING, Any, NamedTuple, Sized
+
+if TYPE_CHECKING:
+ from airflow.models import DagRun
class AirflowException(Exception):
@@ -67,14 +71,6 @@ def __init__(self, reschedule_date):
self.reschedule_date = reschedule_date
-class AirflowSmartSensorException(AirflowException):
- """
- Raise after the task register itself in the smart sensor service.
-
- It should exit without failing a task.
- """
-
-
class InvalidStatsNameException(AirflowException):
"""Raise when name of the stats is invalid."""
@@ -88,7 +84,7 @@ class AirflowWebServerTimeout(AirflowException):
class AirflowSkipException(AirflowException):
- """Raise when the task should be skipped"""
+ """Raise when the task should be skipped."""
class AirflowFailException(AirflowException):
@@ -99,6 +95,19 @@ class AirflowOptionalProviderFeatureException(AirflowException):
"""Raise by providers when imports are missing for optional provider features."""
+class XComNotFound(AirflowException):
+ """Raise when an XCom reference is being resolved against a non-existent XCom."""
+
+ def __init__(self, dag_id: str, task_id: str, key: str) -> None:
+ super().__init__()
+ self.dag_id = dag_id
+ self.task_id = task_id
+ self.key = key
+
+ def __str__(self) -> str:
+ return f'XComArg result from {self.task_id} at {self.dag_id} with key="{self.key}" is not found!'
+
+
class UnmappableOperator(AirflowException):
"""Raise when an operator is not implemented to be mappable."""
@@ -113,12 +122,14 @@ def __str__(self) -> str:
class UnmappableXComTypePushed(AirflowException):
"""Raise when an unmappable type is pushed as a mapped downstream's dependency."""
- def __init__(self, value: Any) -> None:
- super().__init__(value)
- self.value = value
+ def __init__(self, value: Any, *values: Any) -> None:
+ super().__init__(value, *values)
def __str__(self) -> str:
- return f"unmappable return type {type(self.value).__qualname__!r}"
+ typename = type(self.args[0]).__qualname__
+ for arg in self.args[1:]:
+ typename = f"{typename}[{type(arg).__qualname__}]"
+ return f"unmappable return type {typename!r}"
class UnmappableXComLengthPushed(AirflowException):
@@ -150,6 +161,10 @@ def __str__(self) -> str:
return f"Ignoring DAG {self.dag_id} from {self.incoming} - also found in {self.existing}"
+class AirflowDagInconsistent(AirflowException):
+ """Raise when a DAG has inconsistent attributes."""
+
+
class AirflowClusterPolicyViolation(AirflowException):
"""Raise when there is a violation of a Cluster Policy in DAG definition."""
@@ -173,6 +188,12 @@ class DagRunNotFound(AirflowNotFoundException):
class DagRunAlreadyExists(AirflowBadRequest):
"""Raise when creating a DAG run for DAG which already has DAG run entry."""
+ def __init__(self, dag_run: DagRun, execution_date: datetime.datetime, run_id: str) -> None:
+ super().__init__(
+ f"A DAG Run already exists for DAG {dag_run.dag_id} at {execution_date} with run id {run_id}"
+ )
+ self.dag_run = dag_run
+
class DagFileExists(AirflowBadRequest):
"""Raise when a DAG ID is still in DagBag i.e., DAG file is in DAG folder."""
@@ -186,12 +207,29 @@ class DuplicateTaskIdFound(AirflowException):
"""Raise when a Task with duplicate task_id is defined in the same DAG."""
+class TaskAlreadyInTaskGroup(AirflowException):
+ """Raise when a Task cannot be added to a TaskGroup since it already belongs to another TaskGroup."""
+
+ def __init__(self, task_id: str, existing_group_id: str | None, new_group_id: str) -> None:
+ super().__init__(task_id, new_group_id)
+ self.task_id = task_id
+ self.existing_group_id = existing_group_id
+ self.new_group_id = new_group_id
+
+ def __str__(self) -> str:
+ if self.existing_group_id is None:
+ existing_group = "the DAG's root group"
+ else:
+ existing_group = f"group {self.existing_group_id!r}"
+ return f"cannot add {self.task_id!r} to {self.new_group_id!r} (already in {existing_group})"
+
+
class SerializationError(AirflowException):
- """A problem occurred when trying to serialize a DAG."""
+ """A problem occurred when trying to serialize something."""
class ParamValidationError(AirflowException):
- """Raise when DAG params is invalid"""
+ """Raise when DAG params is invalid."""
class TaskNotFound(AirflowNotFoundException):
@@ -234,7 +272,7 @@ def __init__(self, message, ti_status):
class FileSyntaxError(NamedTuple):
"""Information about a single error in a file."""
- line_no: Optional[int]
+ line_no: int | None
message: str
def __str__(self):
@@ -250,7 +288,7 @@ class AirflowFileParseException(AirflowException):
:param parse_errors: File syntax errors
"""
- def __init__(self, msg: str, file_path: str, parse_errors: List[FileSyntaxError]) -> None:
+ def __init__(self, msg: str, file_path: str, parse_errors: list[FileSyntaxError]) -> None:
super().__init__(msg)
self.msg = msg
self.file_path = file_path
@@ -279,6 +317,8 @@ class ConnectionNotUnique(AirflowException):
class TaskDeferred(BaseException):
"""
+ Signal an operator moving to deferred state.
+
Special exception raised to signal that the operator it was raised from
wishes to defer until a trigger fires.
"""
@@ -288,8 +328,8 @@ def __init__(
*,
trigger,
method_name: str,
- kwargs: Optional[Dict[str, Any]] = None,
- timeout: Optional[datetime.timedelta] = None,
+ kwargs: dict[str, Any] | None = None,
+ timeout: datetime.timedelta | None = None,
):
super().__init__()
self.trigger = trigger
@@ -306,3 +346,25 @@ def __repr__(self) -> str:
class TaskDeferralError(AirflowException):
"""Raised when a task failed during deferral for some reason."""
+
+
+class PodMutationHookException(AirflowException):
+ """Raised when exception happens during Pod Mutation Hook execution."""
+
+
+class PodReconciliationError(AirflowException):
+ """Raised when an error is encountered while trying to merge pod configs."""
+
+
+class RemovedInAirflow3Warning(DeprecationWarning):
+ """Issued for usage of deprecated features that will be removed in Airflow3."""
+
+ deprecated_since: str | None = None
+ "Indicates the airflow version that started raising this deprecation warning"
+
+
+class AirflowProviderDeprecationWarning(DeprecationWarning):
+ """Issued for usage of deprecated features of Airflow provider."""
+
+ deprecated_provider_since: str | None = None
+ "Indicates the provider version that started raising this deprecation warning"
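
The reworked `UnmappableXComTypePushed.__str__` in this file builds a nested type name from all positional arguments. A minimal sketch of that behaviour follows; it subclasses the built-in `Exception` rather than `AirflowException` so that it runs without Airflow installed, which is an assumption of this example and not how the real class is declared.

```python
from typing import Any


class UnmappableXComTypePushed(Exception):
    """Stand-in for the Airflow exception, based on Exception for this sketch."""

    def __init__(self, value: Any, *values: Any) -> None:
        # All values are stored on self.args by the base Exception.
        super().__init__(value, *values)

    def __str__(self) -> str:
        # Start from the outer type and wrap each further argument's type
        # in brackets, giving names such as dict[int].
        typename = type(self.args[0]).__qualname__
        for arg in self.args[1:]:
            typename = f"{typename}[{type(arg).__qualname__}]"
        return f"unmappable return type {typename!r}"


single = str(UnmappableXComTypePushed(42))
nested = str(UnmappableXComTypePushed({}, 1))
print(single)
print(nested)
```

With one argument the message names only that argument's type; with a container plus an element it names both.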
diff --git a/airflow/executors/base_executor.py b/airflow/executors/base_executor.py
index 7fcbd0642e14d..0c9af11864f99 100644
--- a/airflow/executors/base_executor.py
+++ b/airflow/executors/base_executor.py
@@ -15,21 +15,23 @@
# specific language governing permissions and limitations
# under the License.
"""Base executor - this is the base class for all the implemented executors."""
+from __future__ import annotations
+
import sys
+import warnings
from collections import OrderedDict
-from typing import Any, Counter, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Counter, List, Optional, Sequence, Tuple
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
from airflow.configuration import conf
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
-PARALLELISM: int = conf.getint('core', 'PARALLELISM')
-
-NOT_STARTED_MESSAGE = "The executor should be started first!"
+PARALLELISM: int = conf.getint("core", "PARALLELISM")
QUEUEING_ATTEMPTS = 5
@@ -62,15 +64,15 @@ class BaseExecutor(LoggingMixin):
``0`` for infinity
"""
- job_id: Union[None, int, str] = None
- callback_sink: Optional[BaseCallbackSink] = None
+ job_id: None | int | str = None
+ callback_sink: BaseCallbackSink | None = None
def __init__(self, parallelism: int = PARALLELISM):
super().__init__()
self.parallelism: int = parallelism
self.queued_tasks: OrderedDict[TaskInstanceKey, QueuedTaskInstanceType] = OrderedDict()
- self.running: Set[TaskInstanceKey] = set()
- self.event_buffer: Dict[TaskInstanceKey, EventBufferValueType] = {}
+ self.running: set[TaskInstanceKey] = set()
+ self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
self.attempts: Counter[TaskInstanceKey] = Counter()
def __repr__(self):
@@ -84,7 +86,7 @@ def queue_command(
task_instance: TaskInstance,
command: CommandType,
priority: int = 1,
- queue: Optional[str] = None,
+ queue: str | None = None,
):
"""Queues command to task"""
if task_instance.key not in self.queued_tasks:
@@ -97,13 +99,13 @@ def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
- pickle_id: Optional[str] = None,
+ pickle_id: str | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
) -> None:
"""Queues task instance."""
pool = pool or task_instance.pool
@@ -160,9 +162,9 @@ def heartbeat(self) -> None:
self.log.debug("%s in queue", num_queued_tasks)
self.log.debug("%s open slots", open_slots)
- Stats.gauge('executor.open_slots', open_slots)
- Stats.gauge('executor.queued_tasks', num_queued_tasks)
- Stats.gauge('executor.running_tasks', num_running_tasks)
+ Stats.gauge("executor.open_slots", open_slots)
+ Stats.gauge("executor.queued_tasks", num_queued_tasks)
+ Stats.gauge("executor.running_tasks", num_running_tasks)
self.trigger_tasks(open_slots)
@@ -170,7 +172,7 @@ def heartbeat(self) -> None:
self.log.debug("Calling the %s sync method", self.__class__)
self.sync()
- def order_queued_tasks_by_priority(self) -> List[Tuple[TaskInstanceKey, QueuedTaskInstanceType]]:
+ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, QueuedTaskInstanceType]]:
"""
Orders the queued tasks by priority.
@@ -221,7 +223,7 @@ def trigger_tasks(self, open_slots: int) -> None:
if task_tuples:
self._process_tasks(task_tuples)
- def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
+ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
for key, command, queue, executor_config in task_tuples:
del self.queued_tasks[key]
self.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
@@ -239,7 +241,7 @@ def change_state(self, key: TaskInstanceKey, state: str, info=None) -> None:
try:
self.running.remove(key)
except KeyError:
- self.log.debug('Could not find key: %s', str(key))
+ self.log.debug("Could not find key: %s", str(key))
self.event_buffer[key] = state, info
def fail(self, key: TaskInstanceKey, info=None) -> None:
@@ -260,7 +262,7 @@ def success(self, key: TaskInstanceKey, info=None) -> None:
"""
self.change_state(key, State.SUCCESS, info)
- def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferValueType]:
+ def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Returns and flush the event buffer. In case dag_ids is specified
it will only return and flush events for the given dag_ids. Otherwise
@@ -269,7 +271,7 @@ def get_event_buffer(self, dag_ids=None) -> Dict[TaskInstanceKey, EventBufferVal
:param dag_ids: the dag_ids to return events for; returns all if given ``None``.
:return: a dict of events
"""
- cleared_events: Dict[TaskInstanceKey, EventBufferValueType] = {}
+ cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {}
if dag_ids is None:
cleared_events = self.event_buffer
self.event_buffer = {}
@@ -284,8 +286,8 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None: # pragma: no cover
"""
This method will execute the command asynchronously.
@@ -309,7 +311,7 @@ def terminate(self):
"""This method is called when the daemon receives a SIGTERM"""
raise NotImplementedError()
- def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Try to adopt running task instances that have been abandoned by a SchedulerJob dying.
@@ -317,7 +319,6 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
re-scheduling)
:return: any TaskInstances that were unable to be adopted
- :rtype: list[airflow.models.TaskInstance]
"""
# By default, assume Executors cannot adopt tasks, so just say we failed to adopt anything.
# Subclasses can do better!
@@ -332,10 +333,43 @@ def slots_available(self):
return sys.maxsize
@staticmethod
- def validate_command(command: List[str]) -> None:
- """Check if the command to execute is airflow command"""
+ def validate_command(command: list[str]) -> None:
+ """
+ Back-compat method to check whether the command to execute is an airflow command.
+
+ :param command: command to check
+ :return: None
+ """
+ warnings.warn(
+ """
+ The `validate_command` method is deprecated. Please use ``validate_airflow_tasks_run_command``
+ """,
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
+ BaseExecutor.validate_airflow_tasks_run_command(command)
+
+ @staticmethod
+ def validate_airflow_tasks_run_command(command: list[str]) -> tuple[str | None, str | None]:
+ """
+ Check whether the command to execute is an airflow command.
+
+ Returns a tuple (dag_id, task_id) retrieved from the command (None for any value that is missing).
+ """
if command[0:3] != ["airflow", "tasks", "run"]:
raise ValueError('The command must start with ["airflow", "tasks", "run"].')
+ if len(command) > 3 and "--help" not in command:
+ dag_id: str | None = None
+ task_id: str | None = None
+ for arg in command[3:]:
+ if not arg.startswith("--"):
+ if dag_id is None:
+ dag_id = arg
+ else:
+ task_id = arg
+ break
+ return dag_id, task_id
+ return None, None
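
The parsing rule added in `validate_airflow_tasks_run_command` can be exercised in isolation. The sketch below copies the logic into a free function; the name `parse_tasks_run_command` is hypothetical and chosen only for this example, and `typing` generics replace the `from __future__ import annotations` style used in the patch so the snippet stands alone.

```python
from typing import List, Optional, Tuple


def parse_tasks_run_command(command: List[str]) -> Tuple[Optional[str], Optional[str]]:
    """Return (dag_id, task_id) from an ``airflow tasks run`` command."""
    if command[0:3] != ["airflow", "tasks", "run"]:
        raise ValueError('The command must start with ["airflow", "tasks", "run"].')
    if len(command) > 3 and "--help" not in command:
        dag_id: Optional[str] = None
        task_id: Optional[str] = None
        # The first two arguments that do not start with "--" are taken
        # as dag_id and task_id, in that order.
        for arg in command[3:]:
            if not arg.startswith("--"):
                if dag_id is None:
                    dag_id = arg
                else:
                    task_id = arg
                    break
        return dag_id, task_id
    return None, None


plain = parse_tasks_run_command(["airflow", "tasks", "run", "my_dag", "my_task", "run_1"])
flagged = parse_tasks_run_command(["airflow", "tasks", "run", "--local", "my_dag", "my_task"])
help_only = parse_tasks_run_command(["airflow", "tasks", "run", "--help"])
print(plain, flagged, help_only)
```

Because every argument that does not begin with `--` counts as positional, a space-separated option value placed before the DAG id would be read as the DAG id. The function relies on the positional arguments appearing in the expected order.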
def debug_dump(self):
"""Called in response to SIGUSR2 by the scheduler"""
diff --git a/airflow/executors/celery_executor.py b/airflow/executors/celery_executor.py
index 7b4c04e225a75..896a46694cfd0 100644
--- a/airflow/executors/celery_executor.py
+++ b/airflow/executors/celery_executor.py
@@ -21,6 +21,8 @@
For more information on how the CeleryExecutor works, take a look at the guide:
:ref:`executor:CeleryExecutor`
"""
+from __future__ import annotations
+
import datetime
import logging
import math
@@ -33,7 +35,7 @@
from concurrent.futures import ProcessPoolExecutor
from enum import Enum
from multiprocessing import cpu_count
-from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Set, Tuple, Union
+from typing import Any, Mapping, MutableMapping, Optional, Sequence, Tuple
from celery import Celery, Task, states as celery_states
from celery.backends.base import BaseKeyValueStoreBackend
@@ -50,6 +52,7 @@
from airflow.executors.base_executor import BaseExecutor, CommandType, EventBufferValueType, TaskTuple
from airflow.models.taskinstance import TaskInstance, TaskInstanceKey
from airflow.stats import Stats
+from airflow.utils.dag_parsing_context import _airflow_parsing_context_manager
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.session import NEW_SESSION, provide_session
@@ -60,43 +63,43 @@
log = logging.getLogger(__name__)
# Make it constant for unit test.
-CELERY_FETCH_ERR_MSG_HEADER = 'Error fetching Celery task state'
+CELERY_FETCH_ERR_MSG_HEADER = "Error fetching Celery task state"
-CELERY_SEND_ERR_MSG_HEADER = 'Error sending Celery task'
+CELERY_SEND_ERR_MSG_HEADER = "Error sending Celery task"
-OPERATION_TIMEOUT = conf.getfloat('celery', 'operation_timeout', fallback=1.0)
+OPERATION_TIMEOUT = conf.getfloat("celery", "operation_timeout", fallback=1.0)
-'''
+"""
To start the celery worker, run the command:
airflow celery worker
-'''
+"""
-if conf.has_option('celery', 'celery_config_options'):
- celery_configuration = conf.getimport('celery', 'celery_config_options')
+if conf.has_option("celery", "celery_config_options"):
+ celery_configuration = conf.getimport("celery", "celery_config_options")
else:
celery_configuration = DEFAULT_CELERY_CONFIG
-app = Celery(conf.get('celery', 'CELERY_APP_NAME'), config_source=celery_configuration)
+app = Celery(conf.get("celery", "CELERY_APP_NAME"), config_source=celery_configuration)
@app.task
def execute_command(command_to_exec: CommandType) -> None:
"""Executes command."""
- BaseExecutor.validate_command(command_to_exec)
+ dag_id, task_id = BaseExecutor.validate_airflow_tasks_run_command(command_to_exec)
celery_task_id = app.current_task.request.id
log.info("[%s] Executing command in Celery: %s", celery_task_id, command_to_exec)
-
- try:
- if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
- _execute_in_subprocess(command_to_exec, celery_task_id)
- else:
- _execute_in_fork(command_to_exec, celery_task_id)
- except Exception:
- Stats.incr("celery.execute_command.failure")
- raise
+ with _airflow_parsing_context_manager(dag_id=dag_id, task_id=task_id):
+ try:
+ if settings.EXECUTE_TASKS_NEW_PYTHON_INTERPRETER:
+ _execute_in_subprocess(command_to_exec, celery_task_id)
+ else:
+ _execute_in_fork(command_to_exec, celery_task_id)
+ except Exception:
+ Stats.incr("celery.execute_command.failure")
+ raise
-def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None:
+def _execute_in_fork(command_to_exec: CommandType, celery_task_id: str | None = None) -> None:
pid = os.fork()
if pid:
# In parent, wait for the child
@@ -104,7 +107,7 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str]
if ret == 0:
return
- msg = f'Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}'
+ msg = f"Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}"
raise AirflowException(msg)
from airflow.sentry import Sentry
@@ -124,7 +127,6 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str]
args.external_executor_id = celery_task_id
setproctitle(f"airflow task supervisor: {command_to_exec}")
-
args.func(args)
ret = 0
except Exception as e:
@@ -136,16 +138,16 @@ def _execute_in_fork(command_to_exec: CommandType, celery_task_id: Optional[str]
os._exit(ret)
-def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: Optional[str] = None) -> None:
+def _execute_in_subprocess(command_to_exec: CommandType, celery_task_id: str | None = None) -> None:
env = os.environ.copy()
if celery_task_id:
env["external_executor_id"] = celery_task_id
try:
subprocess.check_output(command_to_exec, stderr=subprocess.STDOUT, close_fds=True, env=env)
except subprocess.CalledProcessError as e:
- log.exception('[%s] execute_command encountered a CalledProcessError', celery_task_id)
+ log.exception("[%s] execute_command encountered a CalledProcessError", celery_task_id)
log.error(e.output)
- msg = f'Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}'
+ msg = f"Celery command failed on host: {get_hostname()} with celery_task_id {celery_task_id}"
raise AirflowException(msg)
@@ -169,7 +171,7 @@ def __init__(self, exception: Exception, exception_traceback: str):
def send_task_to_executor(
task_tuple: TaskInstanceInCelery,
-) -> Tuple[TaskInstanceKey, CommandType, Union[AsyncResult, ExceptionWithTraceback]]:
+) -> tuple[TaskInstanceKey, CommandType, AsyncResult | ExceptionWithTraceback]:
"""Sends task to executor."""
key, command, queue, task_to_run = task_tuple
try:
@@ -233,36 +235,35 @@ def __init__(self):
# Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
# so we use a multiprocessing pool to speed this up.
# How many worker processes are created for checking celery task state.
- self._sync_parallelism = conf.getint('celery', 'SYNC_PARALLELISM')
+ self._sync_parallelism = conf.getint("celery", "SYNC_PARALLELISM")
if self._sync_parallelism == 0:
self._sync_parallelism = max(1, cpu_count() - 1)
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism)
self.tasks = {}
- self.stalled_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
+ self.stalled_task_timeouts: dict[TaskInstanceKey, datetime.datetime] = {}
self.stalled_task_timeout = datetime.timedelta(
- seconds=conf.getint('celery', 'stalled_task_timeout', fallback=0)
+ seconds=conf.getint("celery", "stalled_task_timeout", fallback=0)
)
- self.adopted_task_timeouts: Dict[TaskInstanceKey, datetime.datetime] = {}
+ self.adopted_task_timeouts: dict[TaskInstanceKey, datetime.datetime] = {}
self.task_adoption_timeout = (
- datetime.timedelta(seconds=conf.getint('celery', 'task_adoption_timeout', fallback=600))
+ datetime.timedelta(seconds=conf.getint("celery", "task_adoption_timeout", fallback=600))
or self.stalled_task_timeout
)
self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
- self.task_publish_max_retries = conf.getint('celery', 'task_publish_max_retries', fallback=3)
+ self.task_publish_max_retries = conf.getint("celery", "task_publish_max_retries", fallback=3)
def start(self) -> None:
- self.log.debug('Starting Celery Executor using %s processes for syncing', self._sync_parallelism)
+ self.log.debug("Starting Celery Executor using %s processes for syncing", self._sync_parallelism)
def _num_tasks_per_send_process(self, to_send_count: int) -> int:
"""
How many Celery tasks should each worker process send.
:return: Number of tasks that should be sent per process
- :rtype: int
"""
return max(1, int(math.ceil(1.0 * to_send_count / self._sync_parallelism)))
- def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
+ def _process_tasks(self, task_tuples: list[TaskTuple]) -> None:
task_tuples_to_send = [task_tuple[:3] + (execute_command,) for task_tuple in task_tuples]
first_task = next(t[3] for t in task_tuples_to_send)
@@ -271,7 +272,7 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
cached_celery_backend = first_task.backend
key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
- self.log.debug('Sent all tasks.')
+ self.log.debug("Sent all tasks.")
for key, _, result in key_and_async_results:
if isinstance(result, ExceptionWithTraceback) and isinstance(
@@ -305,9 +306,9 @@ def _process_tasks(self, task_tuples: List[TaskTuple]) -> None:
self.event_buffer[key] = (State.QUEUED, result.task_id)
# If the task runs _really quickly_ we may already have a result!
- self.update_task_state(key, result.state, getattr(result, 'info', None))
+ self.update_task_state(key, result.state, getattr(result, "info", None))
- def _send_tasks_to_celery(self, task_tuples_to_send: List[TaskInstanceInCelery]):
+ def _send_tasks_to_celery(self, task_tuples_to_send: list[TaskInstanceInCelery]):
if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
# One tuple, or max one process -> send it in the main thread.
return list(map(send_task_to_executor, task_tuples_to_send))
@@ -354,8 +355,8 @@ def _check_for_stalled_tasks(self) -> None:
self._send_stalled_tis_back_to_scheduler(timedout_keys)
def _get_timedout_ti_keys(
- self, task_timeouts: Dict[TaskInstanceKey, datetime.datetime]
- ) -> List[TaskInstanceKey]:
+ self, task_timeouts: dict[TaskInstanceKey, datetime.datetime]
+ ) -> list[TaskInstanceKey]:
"""
These timeouts exist to check to see if any of our tasks have not progressed
in the expected time. This can happen for a few different reasons, usually related
@@ -387,7 +388,7 @@ def _get_timedout_ti_keys(
@provide_session
def _send_stalled_tis_back_to_scheduler(
- self, keys: List[TaskInstanceKey], session: Session = NEW_SESSION
+ self, keys: list[TaskInstanceKey], session: Session = NEW_SESSION
) -> None:
try:
session.query(TaskInstance).filter(
@@ -478,7 +479,7 @@ def end(self, synchronous: bool = False) -> None:
def terminate(self):
pass
- def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
# See which of the TIs are still alive (or have finished even!)
#
# Since Celery doesn't store "SENT" state for queued commands (if we create an AsyncResult with a made
@@ -528,7 +529,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
adopted.append(f"{ti} in state {state}")
if adopted:
- task_instance_str = '\n\t'.join(adopted)
+ task_instance_str = "\n\t".join(adopted)
self.log.info(
"Adopted the following %d tasks from a dead executor\n\t%s", len(adopted), task_instance_str
)
@@ -536,7 +537,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
return not_adopted_tis
def _set_celery_pending_task_timeout(
- self, key: TaskInstanceKey, timeout_type: Optional[_CeleryPendingTaskTimeoutType]
+ self, key: TaskInstanceKey, timeout_type: _CeleryPendingTaskTimeoutType | None
) -> None:
"""
We use the fact that dicts maintain insertion order, and the timeout for a
@@ -551,7 +552,7 @@ def _set_celery_pending_task_timeout(
self.stalled_task_timeouts[key] = utcnow() + self.stalled_task_timeout
-def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str, ExceptionWithTraceback], Any]:
+def fetch_celery_task_state(async_result: AsyncResult) -> tuple[str, str | ExceptionWithTraceback, Any]:
"""
Fetch and return the state of the given celery task. The scope of this function is
global so that it can be called by subprocesses in the pool.
@@ -560,13 +561,12 @@ def fetch_celery_task_state(async_result: AsyncResult) -> Tuple[str, Union[str,
to fetch the task's state
:return: a tuple of the Celery task key and the Celery state and the celery info
of the task
- :rtype: tuple[str, str, str]
"""
try:
with timeout(seconds=OPERATION_TIMEOUT):
# Accessing state property of celery task will make actual network request
# to get the current state of the task
- info = async_result.info if hasattr(async_result, 'info') else None
+ info = async_result.info if hasattr(async_result, "info") else None
return async_result.task_id, async_result.state, info
except Exception as e:
exception_traceback = f"Celery Task ID: {async_result}\n{traceback.format_exc()}"
@@ -586,7 +586,7 @@ def __init__(self, sync_parallelism=None):
super().__init__()
self._sync_parallelism = sync_parallelism
- def _tasks_list_to_task_ids(self, async_tasks) -> Set[str]:
+ def _tasks_list_to_task_ids(self, async_tasks) -> set[str]:
return {a.task_id for a in async_tasks}
def get_many(self, async_results) -> Mapping[str, EventBufferValueType]:
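
Reviewer note: ``_num_tasks_per_send_process`` in the hunk above is a ceiling division with a floor of 1. The standalone sketch below reproduces only that arithmetic. The free function and its ``sync_parallelism`` parameter are stand-ins for the method and ``self._sync_parallelism``, and the sample numbers are illustrative.

```python
import math


def num_tasks_per_send_process(to_send_count: int, sync_parallelism: int) -> int:
    """Return how many tasks each worker process should send (never less than 1)."""
    # Same expression as in the diff: ceil(to_send_count / sync_parallelism), floored at 1.
    return max(1, int(math.ceil(1.0 * to_send_count / sync_parallelism)))


# Illustrative values only: 10 tasks spread over 4 sync processes.
chunk_size = num_tasks_per_send_process(10, 4)
```

The ``max(1, ...)`` guard matters when there is nothing to send: the ceiling of zero is zero, and a chunk size of zero would presumably not be usable for splitting work across the pool.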
diff --git a/airflow/executors/celery_kubernetes_executor.py b/airflow/executors/celery_kubernetes_executor.py
index b1edc32235727..cf18158b09531 100644
--- a/airflow/executors/celery_kubernetes_executor.py
+++ b/airflow/executors/celery_kubernetes_executor.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, List, Optional, Set, Union
+from __future__ import annotations
+
+from typing import Sequence
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
@@ -37,19 +39,19 @@ class CeleryKubernetesExecutor(LoggingMixin):
"""
supports_ad_hoc_ti_run: bool = True
- callback_sink: Optional[BaseCallbackSink] = None
+ callback_sink: BaseCallbackSink | None = None
- KUBERNETES_QUEUE = conf.get('celery_kubernetes_executor', 'kubernetes_queue')
+ KUBERNETES_QUEUE = conf.get("celery_kubernetes_executor", "kubernetes_queue")
def __init__(self, celery_executor: CeleryExecutor, kubernetes_executor: KubernetesExecutor):
super().__init__()
- self._job_id: Optional[int] = None
+ self._job_id: int | None = None
self.celery_executor = celery_executor
self.kubernetes_executor = kubernetes_executor
self.kubernetes_executor.kubernetes_queue = self.KUBERNETES_QUEUE
@property
- def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]:
+ def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]:
"""Return queued tasks from celery and kubernetes executor"""
queued_tasks = self.celery_executor.queued_tasks.copy()
queued_tasks.update(self.kubernetes_executor.queued_tasks)
@@ -57,12 +59,12 @@ def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]:
return queued_tasks
@property
- def running(self) -> Set[TaskInstanceKey]:
+ def running(self) -> set[TaskInstanceKey]:
"""Return running tasks from celery and kubernetes executor"""
return self.celery_executor.running.union(self.kubernetes_executor.running)
@property
- def job_id(self) -> Optional[int]:
+ def job_id(self) -> int | None:
"""
This is a class attribute in BaseExecutor but since this is not really an executor, but a wrapper
of executors we implement as property so we can have custom setter.
@@ -70,7 +72,7 @@ def job_id(self) -> Optional[int]:
return self._job_id
@job_id.setter
- def job_id(self, value: Optional[int]) -> None:
+ def job_id(self, value: int | None) -> None:
"""job_id is manipulated by SchedulerJob. We must propagate the job_id to wrapped executors."""
self._job_id = value
self.kubernetes_executor.job_id = value
@@ -91,7 +93,7 @@ def queue_command(
task_instance: TaskInstance,
command: CommandType,
priority: int = 1,
- queue: Optional[str] = None,
+ queue: str | None = None,
) -> None:
"""Queues command via celery or kubernetes executor"""
executor = self._router(task_instance)
@@ -102,13 +104,13 @@ def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
- pickle_id: Optional[str] = None,
+ pickle_id: str | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
) -> None:
"""Queues task instance via celery or kubernetes executor"""
executor = self._router(SimpleTaskInstance.from_ti(task_instance))
@@ -144,8 +146,8 @@ def heartbeat(self) -> None:
self.kubernetes_executor.heartbeat()
def get_event_buffer(
- self, dag_ids: Optional[List[str]] = None
- ) -> Dict[TaskInstanceKey, EventBufferValueType]:
+ self, dag_ids: list[str] | None = None
+ ) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Returns and flushes the event buffer from celery and kubernetes executor
@@ -157,7 +159,7 @@ def get_event_buffer(
return {**cleared_events_from_celery, **cleared_events_from_kubernetes}
- def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Try to adopt running task instances that have been abandoned by a SchedulerJob dying.
@@ -165,19 +167,13 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
re-scheduling)
:return: any TaskInstances that were unable to be adopted
- :rtype: list[airflow.models.TaskInstance]
"""
- celery_tis = []
- kubernetes_tis = []
- abandoned_tis = []
- for ti in tis:
- if ti.queue == self.KUBERNETES_QUEUE:
- kubernetes_tis.append(ti)
- else:
- celery_tis.append(ti)
- abandoned_tis.extend(self.celery_executor.try_adopt_task_instances(celery_tis))
- abandoned_tis.extend(self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis))
- return abandoned_tis
+ celery_tis = [ti for ti in tis if ti.queue != self.KUBERNETES_QUEUE]
+ kubernetes_tis = [ti for ti in tis if ti.queue == self.KUBERNETES_QUEUE]
+ return [
+ *self.celery_executor.try_adopt_task_instances(celery_tis),
+ *self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis),
+ ]
def end(self) -> None:
"""End celery and kubernetes executor"""
@@ -189,13 +185,12 @@ def terminate(self) -> None:
self.celery_executor.terminate()
self.kubernetes_executor.terminate()
- def _router(self, simple_task_instance: SimpleTaskInstance) -> Union[CeleryExecutor, KubernetesExecutor]:
+ def _router(self, simple_task_instance: SimpleTaskInstance) -> CeleryExecutor | KubernetesExecutor:
"""
Return either celery_executor or kubernetes_executor
:param simple_task_instance: SimpleTaskInstance
:return: celery_executor or kubernetes_executor
- :rtype: Union[CeleryExecutor, KubernetesExecutor]
"""
if simple_task_instance.queue == self.KUBERNETES_QUEUE:
return self.kubernetes_executor
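
Reviewer note: the rewritten ``try_adopt_task_instances`` in this file splits task instances by queue with two list comprehensions, then concatenates whatever each wrapped executor could not adopt. The sketch below shows only the split. ``FakeTI`` and the ``"kubernetes"`` queue value are hypothetical stand-ins; in the diff the queue name is read from the ``celery_kubernetes_executor`` config section.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence

# Hypothetical value; the real name comes from configuration.
KUBERNETES_QUEUE = "kubernetes"


@dataclass
class FakeTI:
    """Hypothetical stand-in for TaskInstance; only ``queue`` matters here."""

    task_id: str
    queue: str


def split_by_queue(tis: Sequence[FakeTI]) -> tuple[list[FakeTI], list[FakeTI]]:
    """Partition task instances into (celery, kubernetes) lists by queue name."""
    celery_tis = [ti for ti in tis if ti.queue != KUBERNETES_QUEUE]
    kubernetes_tis = [ti for ti in tis if ti.queue == KUBERNETES_QUEUE]
    return celery_tis, kubernetes_tis


celery, kube = split_by_queue(
    [FakeTI("a", "default"), FakeTI("b", "kubernetes"), FakeTI("c", "default")]
)
```

The two comprehensions walk the input twice, which is why accepting a ``Sequence`` rather than a one-shot iterator fits this form.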
diff --git a/airflow/executors/dask_executor.py b/airflow/executors/dask_executor.py
index 5d0896455e9fd..a2c2c571630a6 100644
--- a/airflow/executors/dask_executor.py
+++ b/airflow/executors/dask_executor.py
@@ -22,20 +22,22 @@
For more information on how the DaskExecutor works, take a look at the guide:
:ref:`executor:DaskExecutor`
"""
+from __future__ import annotations
+
import subprocess
-from typing import Any, Dict, Optional
+from typing import TYPE_CHECKING, Any
from distributed import Client, Future, as_completed
from distributed.security import Security
from airflow.configuration import conf
from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
+from airflow.executors.base_executor import BaseExecutor, CommandType
from airflow.models.taskinstance import TaskInstanceKey
# queue="default" is a special case since this is the base config default queue name,
# with respect to DaskExecutor, treat it as if no queue is provided
-_UNDEFINED_QUEUES = {None, 'default'}
+_UNDEFINED_QUEUES = {None, "default"}
class DaskExecutor(BaseExecutor):
@@ -44,16 +46,16 @@ class DaskExecutor(BaseExecutor):
def __init__(self, cluster_address=None):
super().__init__(parallelism=0)
if cluster_address is None:
- cluster_address = conf.get('dask', 'cluster_address')
+ cluster_address = conf.get("dask", "cluster_address")
if not cluster_address:
- raise ValueError('Please provide a Dask cluster address in airflow.cfg')
+ raise ValueError("Please provide a Dask cluster address in airflow.cfg")
self.cluster_address = cluster_address
# ssl / tls parameters
- self.tls_ca = conf.get('dask', 'tls_ca')
- self.tls_key = conf.get('dask', 'tls_key')
- self.tls_cert = conf.get('dask', 'tls_cert')
- self.client: Optional[Client] = None
- self.futures: Optional[Dict[Future, TaskInstanceKey]] = None
+ self.tls_ca = conf.get("dask", "tls_ca")
+ self.tls_key = conf.get("dask", "tls_key")
+ self.tls_cert = conf.get("dask", "tls_cert")
+ self.client: Client | None = None
+ self.futures: dict[Future, TaskInstanceKey] | None = None
def start(self) -> None:
if self.tls_ca or self.tls_key or self.tls_cert:
@@ -73,23 +75,22 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None:
+ if TYPE_CHECKING:
+ assert self.client
- self.validate_command(command)
+ self.validate_airflow_tasks_run_command(command)
def airflow_run():
return subprocess.check_call(command, close_fds=True)
- if not self.client:
- raise AirflowException(NOT_STARTED_MESSAGE)
-
resources = None
if queue not in _UNDEFINED_QUEUES:
scheduler_info = self.client.scheduler_info()
avail_queues = {
- resource for d in scheduler_info['workers'].values() for resource in d['resources']
+ resource for d in scheduler_info["workers"].values() for resource in d["resources"]
}
if queue not in avail_queues:
@@ -100,8 +101,9 @@ def airflow_run():
self.futures[future] = key # type: ignore
def _process_future(self, future: Future) -> None:
- if not self.futures:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.futures
+
if future.done():
key = self.futures[future]
if future.exception():
@@ -115,23 +117,25 @@ def _process_future(self, future: Future) -> None:
self.futures.pop(future)
def sync(self) -> None:
- if self.futures is None:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.futures
+
# make a copy so futures can be popped during iteration
for future in self.futures.copy():
self._process_future(future)
def end(self) -> None:
- if not self.client:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if self.futures is None:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.client
+ assert self.futures
+
self.client.cancel(list(self.futures.keys()))
for future in as_completed(self.futures.copy()):
self._process_future(future)
def terminate(self):
- if self.futures is None:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.futures
+
self.client.cancel(self.futures.keys())
self.end()
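
Reviewer note: this file replaces the runtime ``NOT_STARTED_MESSAGE`` checks with ``if TYPE_CHECKING: assert ...`` blocks. ``typing.TYPE_CHECKING`` is ``False`` at runtime, so those asserts only narrow ``Optional`` types for a static checker and never execute. The sketch below uses a hypothetical ``SketchExecutor`` class (not Airflow code) to show the practical consequence as I understand it: calling a method before ``start()`` surfaces as a plain ``AttributeError`` rather than a dedicated exception.

```python
from __future__ import annotations

from typing import TYPE_CHECKING


class SketchExecutor:
    """Hypothetical executor-like class used only to illustrate the pattern."""

    def __init__(self) -> None:
        self.futures: dict[str, str] | None = None  # populated by start()

    def start(self) -> None:
        self.futures = {}

    def sync(self) -> list[str]:
        if TYPE_CHECKING:
            # Seen by type checkers only; this branch never runs.
            assert self.futures
        # If start() was not called, this raises AttributeError on None.
        return list(self.futures.copy())


executor = SketchExecutor()
try:
    executor.sync()
    error_name = None
except AttributeError as exc:
    error_name = type(exc).__name__

executor.start()
synced = executor.sync()
```

Because the assert is never evaluated, an empty ``futures`` dict after ``start()`` is not a problem at runtime even though ``assert self.futures`` would be falsy for it.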
diff --git a/airflow/executors/debug_executor.py b/airflow/executors/debug_executor.py
index 7a2dddec79358..d727ee8795112 100644
--- a/airflow/executors/debug_executor.py
+++ b/airflow/executors/debug_executor.py
@@ -22,9 +22,10 @@
For more information on how the DebugExecutor works, take a look at the guide:
:ref:`executor:DebugExecutor`
"""
+from __future__ import annotations
import threading
-from typing import Any, Dict, List, Optional
+from typing import Any
from airflow.configuration import conf
from airflow.executors.base_executor import BaseExecutor
@@ -44,9 +45,9 @@ class DebugExecutor(BaseExecutor):
def __init__(self):
super().__init__()
- self.tasks_to_run: List[TaskInstance] = []
+ self.tasks_to_run: list[TaskInstance] = []
# Place where we keep information for task instance raw run
- self.tasks_params: Dict[TaskInstanceKey, Dict[str, Any]] = {}
+ self.tasks_params: dict[TaskInstanceKey, dict[str, Any]] = {}
self.fail_fast = conf.getboolean("debug", "fail_fast")
def execute_async(self, *args, **kwargs) -> None:
@@ -75,7 +76,7 @@ def _run_task(self, ti: TaskInstance) -> bool:
key = ti.key
try:
params = self.tasks_params.pop(ti.key, {})
- ti._run_raw_task(job_id=ti.job_id, **params)
+ ti.run(job_id=ti.job_id, **params)
self.change_state(key, State.SUCCESS)
return True
except Exception as e:
@@ -88,13 +89,13 @@ def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
- pickle_id: Optional[str] = None,
+ pickle_id: str | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
) -> None:
"""Queues task instance with empty command because we do not need it."""
self.queue_command(
diff --git a/airflow/executors/executor_constants.py b/airflow/executors/executor_constants.py
index 55a3a7f766301..c6219f3b4a965 100644
--- a/airflow/executors/executor_constants.py
+++ b/airflow/executors/executor_constants.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
LOCAL_EXECUTOR = "LocalExecutor"
LOCAL_KUBERNETES_EXECUTOR = "LocalKubernetesExecutor"
diff --git a/airflow/executors/executor_loader.py b/airflow/executors/executor_loader.py
index 723060db0f179..56802017e4fe3 100644
--- a/airflow/executors/executor_loader.py
+++ b/airflow/executors/executor_loader.py
@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""All executors."""
+from __future__ import annotations
+
import logging
from contextlib import suppress
from enum import Enum, unique
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING
from airflow.exceptions import AirflowConfigException
from airflow.executors.executor_constants import (
@@ -51,33 +53,33 @@ class ConnectorSource(Enum):
class ExecutorLoader:
"""Keeps constants for all the currently available executors."""
- _default_executor: Optional["BaseExecutor"] = None
+ _default_executor: BaseExecutor | None = None
executors = {
- LOCAL_EXECUTOR: 'airflow.executors.local_executor.LocalExecutor',
- LOCAL_KUBERNETES_EXECUTOR: 'airflow.executors.local_kubernetes_executor.LocalKubernetesExecutor',
- SEQUENTIAL_EXECUTOR: 'airflow.executors.sequential_executor.SequentialExecutor',
- CELERY_EXECUTOR: 'airflow.executors.celery_executor.CeleryExecutor',
- CELERY_KUBERNETES_EXECUTOR: 'airflow.executors.celery_kubernetes_executor.CeleryKubernetesExecutor',
- DASK_EXECUTOR: 'airflow.executors.dask_executor.DaskExecutor',
- KUBERNETES_EXECUTOR: 'airflow.executors.kubernetes_executor.KubernetesExecutor',
- DEBUG_EXECUTOR: 'airflow.executors.debug_executor.DebugExecutor',
+ LOCAL_EXECUTOR: "airflow.executors.local_executor.LocalExecutor",
+ LOCAL_KUBERNETES_EXECUTOR: "airflow.executors.local_kubernetes_executor.LocalKubernetesExecutor",
+ SEQUENTIAL_EXECUTOR: "airflow.executors.sequential_executor.SequentialExecutor",
+ CELERY_EXECUTOR: "airflow.executors.celery_executor.CeleryExecutor",
+ CELERY_KUBERNETES_EXECUTOR: "airflow.executors.celery_kubernetes_executor.CeleryKubernetesExecutor",
+ DASK_EXECUTOR: "airflow.executors.dask_executor.DaskExecutor",
+ KUBERNETES_EXECUTOR: "airflow.executors.kubernetes_executor.KubernetesExecutor",
+ DEBUG_EXECUTOR: "airflow.executors.debug_executor.DebugExecutor",
}
@classmethod
- def get_default_executor(cls) -> "BaseExecutor":
- """Creates a new instance of the configured executor if none exists and returns it"""
+ def get_default_executor(cls) -> BaseExecutor:
+ """Creates a new instance of the configured executor if none exists and returns it."""
if cls._default_executor is not None:
return cls._default_executor
from airflow.configuration import conf
- executor_name = conf.get_mandatory_value('core', 'EXECUTOR')
+ executor_name = conf.get_mandatory_value("core", "EXECUTOR")
cls._default_executor = cls.load_executor(executor_name)
return cls._default_executor
@classmethod
- def load_executor(cls, executor_name: str) -> "BaseExecutor":
+ def load_executor(cls, executor_name: str) -> BaseExecutor:
"""
Loads the executor.
@@ -107,7 +109,7 @@ def load_executor(cls, executor_name: str) -> "BaseExecutor":
return executor_cls()
@classmethod
- def import_executor_cls(cls, executor_name: str) -> Tuple[Type["BaseExecutor"], ConnectorSource]:
+ def import_executor_cls(cls, executor_name: str) -> tuple[type[BaseExecutor], ConnectorSource]:
"""
Imports the executor class.
@@ -133,8 +135,7 @@ def import_executor_cls(cls, executor_name: str) -> Tuple[Type["BaseExecutor"],
return import_string(executor_name), ConnectorSource.CUSTOM_PATH
@classmethod
- def __load_celery_kubernetes_executor(cls) -> "BaseExecutor":
- """:return: an instance of CeleryKubernetesExecutor"""
+ def __load_celery_kubernetes_executor(cls) -> BaseExecutor:
celery_executor = import_string(cls.executors[CELERY_EXECUTOR])()
kubernetes_executor = import_string(cls.executors[KUBERNETES_EXECUTOR])()
@@ -142,8 +143,7 @@ def __load_celery_kubernetes_executor(cls) -> "BaseExecutor":
return celery_kubernetes_executor_cls(celery_executor, kubernetes_executor)
@classmethod
- def __load_local_kubernetes_executor(cls) -> "BaseExecutor":
- """:return: an instance of LocalKubernetesExecutor"""
+ def __load_local_kubernetes_executor(cls) -> BaseExecutor:
local_executor = import_string(cls.executors[LOCAL_EXECUTOR])()
kubernetes_executor = import_string(cls.executors[KUBERNETES_EXECUTOR])()
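
Reviewer note: ``ExecutorLoader`` maps short executor names to dotted import paths and caches the default instance on the class. The sketch below is a minimal illustration of that shape under stated assumptions. The registry entries point at standard-library classes as placeholders, and the local ``import_string`` helper is a simplified stand-in, not Airflow's own implementation.

```python
from __future__ import annotations

from importlib import import_module

# Hypothetical registry; stdlib classes stand in for executor classes.
REGISTRY = {
    "Counter": "collections.Counter",
    "Ordered": "collections.OrderedDict",
}


def import_string(dotted_path: str):
    """Simplified stand-in: import ``pkg.module.Name`` and return ``Name``."""
    module_path, _, attr = dotted_path.rpartition(".")
    return getattr(import_module(module_path), attr)


class SketchLoader:
    _default: object | None = None

    @classmethod
    def get_default(cls, name: str) -> object:
        """Create the configured object once and return the cached instance afterwards."""
        if cls._default is not None:
            return cls._default
        # Unknown names are treated as full dotted paths.
        cls._default = import_string(REGISTRY.get(name, name))()
        return cls._default


first = SketchLoader.get_default("Counter")
second = SketchLoader.get_default("Ordered")  # cache hit: the name is ignored
```

Note that the unquoted ``object | None`` annotation relies on ``from __future__ import annotations``, the same mechanism that lets this diff drop the quotes around ``BaseExecutor`` in the loader's annotations.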
diff --git a/airflow/executors/kubernetes_executor.py b/airflow/executors/kubernetes_executor.py
index c76cf58f418d4..28f720f35e1c2 100644
--- a/airflow/executors/kubernetes_executor.py
+++ b/airflow/executors/kubernetes_executor.py
@@ -21,22 +21,24 @@
For more information on how the KubernetesExecutor works, take a look at the guide:
:ref:`executor:KubernetesExecutor`
"""
+from __future__ import annotations
import functools
import json
+import logging
import multiprocessing
import time
from datetime import timedelta
from queue import Empty, Queue
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple
from kubernetes import client, watch
from kubernetes.client import Configuration, models as k8s
from kubernetes.client.rest import ApiException
from urllib3.exceptions import ReadTimeoutError
-from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import NOT_STARTED_MESSAGE, BaseExecutor, CommandType
+from airflow.exceptions import AirflowException, PodMutationHookException, PodReconciliationError
+from airflow.executors.base_executor import BaseExecutor, CommandType
from airflow.kubernetes import pod_generator
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.kube_config import KubeConfig
@@ -77,10 +79,10 @@ class KubernetesJobWatcher(multiprocessing.Process, LoggingMixin):
def __init__(
self,
- namespace: Optional[str],
+ namespace: str | None,
multi_namespace_mode: bool,
- watcher_queue: 'Queue[KubernetesWatchType]',
- resource_version: Optional[str],
+ watcher_queue: Queue[KubernetesWatchType],
+ resource_version: str | None,
scheduler_job_id: str,
kube_config: Configuration,
):
@@ -94,9 +96,10 @@ def __init__(
def run(self) -> None:
"""Performs watching"""
+ if TYPE_CHECKING:
+ assert self.scheduler_job_id
+
kube_client: client.CoreV1Api = get_kube_client()
- if not self.scheduler_job_id:
- raise AirflowException(NOT_STARTED_MESSAGE)
while True:
try:
self.resource_version = self._run(
@@ -108,34 +111,34 @@ def run(self) -> None:
)
time.sleep(1)
except Exception:
- self.log.exception('Unknown error in KubernetesJobWatcher. Failing')
+ self.log.exception("Unknown error in KubernetesJobWatcher. Failing")
self.resource_version = "0"
ResourceVersion().resource_version = "0"
raise
else:
self.log.warning(
- 'Watch died gracefully, starting back up with: last resource_version: %s',
+ "Watch died gracefully, starting back up with: last resource_version: %s",
self.resource_version,
)
def _run(
self,
kube_client: client.CoreV1Api,
- resource_version: Optional[str],
+ resource_version: str | None,
scheduler_job_id: str,
kube_config: Any,
- ) -> Optional[str]:
- self.log.info('Event: and now my watch begins starting at resource_version: %s', resource_version)
+ ) -> str | None:
+ self.log.info("Event: and now my watch begins starting at resource_version: %s", resource_version)
watcher = watch.Watch()
- kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'}
+ kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"}
if resource_version:
- kwargs['resource_version'] = resource_version
+ kwargs["resource_version"] = resource_version
if kube_config.kube_client_request_args:
for key, value in kube_config.kube_client_request_args.items():
kwargs[key] = value
- last_resource_version: Optional[str] = None
+ last_resource_version: str | None = None
if self.multi_namespace_mode:
list_worker_pods = functools.partial(
watcher.stream, kube_client.list_pod_for_all_namespaces, **kwargs
@@ -145,21 +148,21 @@ def _run(
watcher.stream, kube_client.list_namespaced_pod, self.namespace, **kwargs
)
for event in list_worker_pods():
- task = event['object']
- self.log.info('Event: %s had an event of type %s', task.metadata.name, event['type'])
- if event['type'] == 'ERROR':
+ task = event["object"]
+ self.log.debug("Event: %s had an event of type %s", task.metadata.name, event["type"])
+ if event["type"] == "ERROR":
return self.process_error(event)
annotations = task.metadata.annotations
task_instance_related_annotations = {
- 'dag_id': annotations['dag_id'],
- 'task_id': annotations['task_id'],
- 'execution_date': annotations.get('execution_date'),
- 'run_id': annotations.get('run_id'),
- 'try_number': annotations['try_number'],
+ "dag_id": annotations["dag_id"],
+ "task_id": annotations["task_id"],
+ "execution_date": annotations.get("execution_date"),
+ "run_id": annotations.get("run_id"),
+ "try_number": annotations["try_number"],
}
- map_index = annotations.get('map_index')
+ map_index = annotations.get("map_index")
if map_index is not None:
- task_instance_related_annotations['map_index'] = map_index
+ task_instance_related_annotations["map_index"] = map_index
self.process_status(
pod_id=task.metadata.name,
@@ -175,14 +178,14 @@ def _run(
def process_error(self, event: Any) -> str:
"""Process error response"""
- self.log.error('Encountered Error response from k8s list namespaced pod stream => %s', event)
- raw_object = event['raw_object']
- if raw_object['code'] == 410:
+ self.log.error("Encountered Error response from k8s list namespaced pod stream => %s", event)
+ raw_object = event["raw_object"]
+ if raw_object["code"] == 410:
self.log.info(
- 'Kubernetes resource version is too old, must reset to 0 => %s', (raw_object['message'],)
+ "Kubernetes resource version is too old, must reset to 0 => %s", (raw_object["message"],)
)
# Return resource version 0
- return '0'
+ return "0"
raise AirflowException(
f"Kubernetes failure for {raw_object['reason']} with code {raw_object['code']} and message: "
f"{raw_object['message']}"
@@ -193,33 +196,33 @@ def process_status(
pod_id: str,
namespace: str,
status: str,
- annotations: Dict[str, str],
+ annotations: dict[str, str],
resource_version: str,
event: Any,
) -> None:
"""Process status response"""
- if status == 'Pending':
- if event['type'] == 'DELETED':
- self.log.info('Event: Failed to start pod %s', pod_id)
+ if status == "Pending":
+ if event["type"] == "DELETED":
+ self.log.info("Event: Failed to start pod %s", pod_id)
self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version))
else:
- self.log.info('Event: %s Pending', pod_id)
- elif status == 'Failed':
- self.log.error('Event: %s Failed', pod_id)
+ self.log.debug("Event: %s Pending", pod_id)
+ elif status == "Failed":
+ self.log.error("Event: %s Failed", pod_id)
self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version))
- elif status == 'Succeeded':
- self.log.info('Event: %s Succeeded', pod_id)
+ elif status == "Succeeded":
+ self.log.info("Event: %s Succeeded", pod_id)
self.watcher_queue.put((pod_id, namespace, None, annotations, resource_version))
- elif status == 'Running':
- if event['type'] == 'DELETED':
- self.log.info('Event: Pod %s deleted before it could complete', pod_id)
+ elif status == "Running":
+ if event["type"] == "DELETED":
+ self.log.info("Event: Pod %s deleted before it could complete", pod_id)
self.watcher_queue.put((pod_id, namespace, State.FAILED, annotations, resource_version))
else:
- self.log.info('Event: %s is Running', pod_id)
+ self.log.info("Event: %s is Running", pod_id)
else:
self.log.warning(
- 'Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with '
- 'resource_version: %s',
+ "Event: Invalid state: %s on pod: %s in namespace %s with annotations: %s with "
+ "resource_version: %s",
status,
pod_id,
namespace,
@@ -234,8 +237,8 @@ class AirflowKubernetesScheduler(LoggingMixin):
def __init__(
self,
kube_config: Any,
- task_queue: 'Queue[KubernetesJobType]',
- result_queue: 'Queue[KubernetesResultsType]',
+ task_queue: Queue[KubernetesJobType],
+ result_queue: Queue[KubernetesResultsType],
kube_client: client.CoreV1Api,
scheduler_job_id: str,
):
@@ -254,19 +257,22 @@ def __init__(
def run_pod_async(self, pod: k8s.V1Pod, **kwargs):
"""Runs POD asynchronously"""
- pod_mutation_hook(pod)
+ try:
+ pod_mutation_hook(pod)
+ except Exception as e:
+ raise PodMutationHookException(e)
sanitized_pod = self.kube_client.api_client.sanitize_for_serialization(pod)
json_pod = json.dumps(sanitized_pod, indent=2)
- self.log.debug('Pod Creation Request: \n%s', json_pod)
+ self.log.debug("Pod Creation Request: \n%s", json_pod)
try:
resp = self.kube_client.create_namespaced_pod(
body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs
)
- self.log.debug('Pod Creation Response: %s', resp)
+ self.log.debug("Pod Creation Response: %s", resp)
except Exception as e:
- self.log.exception('Exception when attempting to create Namespaced Pod: %s', json_pod)
+ self.log.exception("Exception when attempting to create Namespaced Pod: %s", json_pod)
raise e
return resp
@@ -288,7 +294,7 @@ def _health_check_kube_watcher(self):
self.log.debug("KubeJobWatcher alive, continuing")
else:
self.log.error(
- 'Error while health checking kube watcher process. Process died for unknown reasons'
+ "Error while health checking kube watcher process. Process died for unknown reasons"
)
ResourceVersion().resource_version = "0"
self.kube_watcher = self._make_kube_watcher()
@@ -300,8 +306,8 @@ def run_next(self, next_job: KubernetesJobType) -> None:
and store relevant info in the current_jobs map so we can track the job's
status
"""
- self.log.info('Kubernetes job is %s', str(next_job).replace("\n", " "))
key, command, kube_executor_config, pod_template_file = next_job
+
dag_id, task_id, run_id, try_number, map_index = key
if command[0:3] != ["airflow", "tasks", "run"]:
@@ -331,6 +337,7 @@ def run_next(self, next_job: KubernetesJobType) -> None:
)
# Reconcile the pod generated by the Operator and the Pod
# generated by the .cfg file
+ self.log.info("Creating kubernetes pod for job %s, with pod name %s", key, pod.metadata.name)
self.log.debug("Kubernetes running for command %s", command)
self.log.debug("Kubernetes launching image %s", pod.spec.containers[0].image)
@@ -378,21 +385,21 @@ def sync(self) -> None:
def process_watcher_task(self, task: KubernetesWatchType) -> None:
"""Process the task by watcher."""
pod_id, namespace, state, annotations, resource_version = task
- self.log.info(
- 'Attempting to finish pod; pod_id: %s; state: %s; annotations: %s', pod_id, state, annotations
+ self.log.debug(
+ "Attempting to finish pod; pod_id: %s; state: %s; annotations: %s", pod_id, state, annotations
)
key = annotations_to_key(annotations=annotations)
if key:
- self.log.debug('finishing job %s - %s (%s)', key, state, pod_id)
+ self.log.debug("finishing job %s - %s (%s)", key, state, pod_id)
self.result_queue.put((key, state, pod_id, namespace, resource_version))
def _flush_watcher_queue(self) -> None:
- self.log.debug('Executor shutting down, watcher_queue approx. size=%d', self.watcher_queue.qsize())
+ self.log.debug("Executor shutting down, watcher_queue approx. size=%d", self.watcher_queue.qsize())
while True:
try:
task = self.watcher_queue.get_nowait()
# Ignoring it since it can only have either FAILED or SUCCEEDED pods
- self.log.warning('Executor shutting down, IGNORING watcher task=%s', task)
+ self.log.warning("Executor shutting down, IGNORING watcher task=%s", task)
self.watcher_queue.task_done()
except Empty:
break
@@ -411,7 +418,7 @@ def terminate(self) -> None:
self._manager.shutdown()
-def get_base_pod_from_template(pod_template_file: Optional[str], kube_config: Any) -> k8s.V1Pod:
+def get_base_pod_from_template(pod_template_file: str | None, kube_config: Any) -> k8s.V1Pod:
"""
Reads either the pod_template_file set in the executor_config or the base pod_template_file
set in the airflow.cfg to craft a "base pod" that will be used by the KubernetesExecutor
@@ -434,14 +441,14 @@ class KubernetesExecutor(BaseExecutor):
def __init__(self):
self.kube_config = KubeConfig()
self._manager = multiprocessing.Manager()
- self.task_queue: 'Queue[KubernetesJobType]' = self._manager.Queue()
- self.result_queue: 'Queue[KubernetesResultsType]' = self._manager.Queue()
- self.kube_scheduler: Optional[AirflowKubernetesScheduler] = None
- self.kube_client: Optional[client.CoreV1Api] = None
- self.scheduler_job_id: Optional[str] = None
- self.event_scheduler: Optional[EventScheduler] = None
- self.last_handled: Dict[TaskInstanceKey, float] = {}
- self.kubernetes_queue: Optional[str] = None
+ self.task_queue: Queue[KubernetesJobType] = self._manager.Queue()
+ self.result_queue: Queue[KubernetesResultsType] = self._manager.Queue()
+ self.kube_scheduler: AirflowKubernetesScheduler | None = None
+ self.kube_client: client.CoreV1Api | None = None
+ self.scheduler_job_id: str | None = None
+ self.event_scheduler: EventScheduler | None = None
+ self.last_handled: dict[TaskInstanceKey, float] = {}
+ self.kubernetes_queue: str | None = None
super().__init__(parallelism=self.kube_config.parallelism)
@provide_session
@@ -457,15 +464,17 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
is around, and if not, and there's no matching entry in our own
task_queue, marks it for re-execution.
"""
- self.log.debug("Clearing tasks that have not been launched")
- if not self.kube_client:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.kube_client
- query = session.query(TaskInstance).filter(TaskInstance.state == State.QUEUED)
+ self.log.debug("Clearing tasks that have not been launched")
+ query = session.query(TaskInstance).filter(
+ TaskInstance.state == State.QUEUED, TaskInstance.queued_by_job_id == self.job_id
+ )
if self.kubernetes_queue:
query = query.filter(TaskInstance.queue == self.kubernetes_queue)
- queued_tis: List[TaskInstance] = query.all()
- self.log.info('Found %s queued task instances', len(queued_tis))
+ queued_tis: list[TaskInstance] = query.all()
+ self.log.info("Found %s queued task instances", len(queued_tis))
# Go through the "last seen" dictionary and clean out old entries
allowed_age = self.kube_config.worker_pods_queued_check_interval * 3
@@ -488,25 +497,25 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
)
if ti.map_index >= 0:
# Old tasks _couldn't_ be mapped, so we don't have to worry about compat
- base_label_selector += f',map_index={ti.map_index}'
+ base_label_selector += f",map_index={ti.map_index}"
kwargs = dict(label_selector=base_label_selector)
if self.kube_config.kube_client_request_args:
kwargs.update(**self.kube_config.kube_client_request_args)
# Try run_id first
- kwargs['label_selector'] += ',run_id=' + pod_generator.make_safe_label_value(ti.run_id)
+ kwargs["label_selector"] += ",run_id=" + pod_generator.make_safe_label_value(ti.run_id)
pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs)
if pod_list.items:
continue
# Fallback to old style of using execution_date
- kwargs['label_selector'] = (
- f'{base_label_selector},'
- f'execution_date={pod_generator.datetime_to_label_safe_datestring(ti.execution_date)}'
+ kwargs["label_selector"] = (
+ f"{base_label_selector},"
+ f"execution_date={pod_generator.datetime_to_label_safe_datestring(ti.execution_date)}"
)
pod_list = self.kube_client.list_namespaced_pod(self.kube_config.kube_namespace, **kwargs)
if pod_list.items:
continue
- self.log.info('TaskInstance: %s found in queued state but was not launched, rescheduling', ti)
+ self.log.info("TaskInstance: %s found in queued state but was not launched, rescheduling", ti)
session.query(TaskInstance).filter(
TaskInstance.dag_id == ti.dag_id,
TaskInstance.task_id == ti.task_id,
@@ -516,11 +525,11 @@ def clear_not_launched_queued_tasks(self, session=None) -> None:
def start(self) -> None:
"""Starts the executor"""
- self.log.info('Start Kubernetes executor')
+ self.log.info("Start Kubernetes executor")
if not self.job_id:
raise AirflowException("Could not get scheduler_job_id")
self.scheduler_job_id = str(self.job_id)
- self.log.debug('Start with scheduler_job_id: %s', self.scheduler_job_id)
+ self.log.debug("Start with scheduler_job_id: %s", self.scheduler_job_id)
self.kube_client = get_kube_client()
self.kube_scheduler = AirflowKubernetesScheduler(
self.kube_config, self.task_queue, self.result_queue, self.kube_client, self.scheduler_job_id
@@ -530,6 +539,7 @@ def start(self) -> None:
self.kube_config.worker_pods_pending_timeout_check_interval,
self._check_worker_pods_pending_timeout,
)
+
self.event_scheduler.call_regular_interval(
self.kube_config.worker_pods_queued_check_interval,
self.clear_not_launched_queued_tasks,
@@ -542,15 +552,22 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None:
"""Executes task asynchronously"""
- self.log.info('Add task %s with command %s with executor_config %s', key, command, executor_config)
+ if TYPE_CHECKING:
+ assert self.task_queue
+
+ if self.log.isEnabledFor(logging.DEBUG):
+ self.log.debug("Add task %s with command %s, executor_config %s", key, command, executor_config)
+ else:
+ self.log.info("Add task %s with command %s", key, command)
+
try:
kube_executor_config = PodGenerator.from_obj(executor_config)
except Exception:
- self.log.error("Invalid executor_config for %s", key)
+ self.log.error("Invalid executor_config for %s. Executor_config: %s", key, executor_config)
self.fail(key=key, info="Invalid executor_config passed")
return
@@ -558,8 +575,6 @@ def execute_async(
pod_template_file = executor_config.get("pod_template_file", None)
else:
pod_template_file = None
- if not self.task_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
self.event_buffer[key] = (State.QUEUED, self.scheduler_job_id)
self.task_queue.put((key, command, kube_executor_config, pod_template_file))
# We keep a temporary local record that we've handled this so we don't
@@ -568,22 +583,18 @@ def execute_async(
def sync(self) -> None:
"""Synchronize task state."""
+ if TYPE_CHECKING:
+ assert self.scheduler_job_id
+ assert self.kube_scheduler
+ assert self.kube_config
+ assert self.result_queue
+ assert self.task_queue
+ assert self.event_scheduler
+
if self.running:
- self.log.debug('self.running: %s', self.running)
+ self.log.debug("self.running: %s", self.running)
if self.queued_tasks:
- self.log.debug('self.queued: %s', self.queued_tasks)
- if not self.scheduler_job_id:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.kube_scheduler:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.kube_config:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.result_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.task_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.event_scheduler:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ self.log.debug("self.queued: %s", self.queued_tasks)
self.kube_scheduler.sync()
last_resource_version = None
@@ -593,7 +604,7 @@ def sync(self) -> None:
try:
key, state, pod_id, namespace, resource_version = results
last_resource_version = resource_version
- self.log.info('Changing state of %s to %s', results, state)
+ self.log.info("Changing state of %s to %s", results, state)
try:
self._change_state(key, state, pod_id, namespace)
except Exception as e:
@@ -617,8 +628,14 @@ def sync(self) -> None:
task = self.task_queue.get_nowait()
try:
self.kube_scheduler.run_next(task)
+ except PodReconciliationError as e:
+ self.log.error(
+ "Pod reconciliation failed, likely due to kubernetes library upgrade. "
+ "Try clearing the task to re-run.",
+ exc_info=True,
+ )
+ self.fail(task[0], e)
except ApiException as e:
-
# These codes indicate something is wrong with pod definition; otherwise we assume pod
# definition is ok, and that retrying may work
if e.status in (400, 422):
@@ -627,11 +644,19 @@ def sync(self) -> None:
self.change_state(key, State.FAILED, e)
else:
self.log.warning(
- 'ApiException when attempting to run task, re-queueing. Reason: %r. Message: %s',
+ "ApiException when attempting to run task, re-queueing. Reason: %r. Message: %s",
e.reason,
- json.loads(e.body)['message'],
+ json.loads(e.body)["message"],
)
self.task_queue.put(task)
+ except PodMutationHookException as e:
+ key, _, _, _ = task
+ self.log.error(
+ "Pod Mutation Hook failed for the task %s. Failing task. Details: %s",
+ key,
+ e,
+ )
+ self.fail(key, e)
finally:
self.task_queue.task_done()
except Empty:
@@ -643,15 +668,16 @@ def sync(self) -> None:
def _check_worker_pods_pending_timeout(self):
"""Check if any pending worker pods have timed out"""
- if not self.scheduler_job_id:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.scheduler_job_id
+
timeout = self.kube_config.worker_pods_pending_timeout
- self.log.debug('Looking for pending worker pods older than %d seconds', timeout)
+ self.log.debug("Looking for pending worker pods older than %d seconds", timeout)
kwargs = {
- 'limit': self.kube_config.worker_pods_pending_timeout_batch_size,
- 'field_selector': 'status.phase=Pending',
- 'label_selector': f'airflow-worker={self.scheduler_job_id}',
+ "limit": self.kube_config.worker_pods_pending_timeout_batch_size,
+ "field_selector": "status.phase=Pending",
+ "label_selector": f"airflow-worker={self.scheduler_job_id}",
**self.kube_config.kube_client_request_args,
}
if self.kube_config.multi_namespace_mode:
@@ -670,35 +696,36 @@ def _check_worker_pods_pending_timeout(self):
self.log.error(
(
'Pod "%s" has been pending for longer than %d seconds.'
- 'It will be deleted and set to failed.'
+ " It will be deleted and set to failed."
),
pod.metadata.name,
timeout,
)
self.kube_scheduler.delete_pod(pod.metadata.name, pod.metadata.namespace)
- def _change_state(self, key: TaskInstanceKey, state: Optional[str], pod_id: str, namespace: str) -> None:
+ def _change_state(self, key: TaskInstanceKey, state: str | None, pod_id: str, namespace: str) -> None:
+ if TYPE_CHECKING:
+ assert self.kube_scheduler
+
if state != State.RUNNING:
if self.kube_config.delete_worker_pods:
- if not self.kube_scheduler:
- raise AirflowException(NOT_STARTED_MESSAGE)
if state != State.FAILED or self.kube_config.delete_worker_pods_on_failure:
self.kube_scheduler.delete_pod(pod_id, namespace)
- self.log.info('Deleted pod: %s in namespace %s', str(key), str(namespace))
+ self.log.info("Deleted pod: %s in namespace %s", str(key), str(namespace))
try:
self.running.remove(key)
except KeyError:
- self.log.debug('Could not find key: %s', str(key))
+ self.log.debug("Could not find key: %s", str(key))
self.event_buffer[key] = state, None
- def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id]
scheduler_job_ids = {ti.queued_by_job_id for ti in tis}
pod_ids = {ti.key: ti for ti in tis if ti.queued_by_job_id}
kube_client: client.CoreV1Api = self.kube_client
for scheduler_job_id in scheduler_job_ids:
scheduler_job_id = pod_generator.make_safe_label_value(str(scheduler_job_id))
- kwargs = {'label_selector': f'airflow-worker={scheduler_job_id}'}
+ kwargs = {"label_selector": f"airflow-worker={scheduler_job_id}"}
pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs)
for pod in pod_list.items:
self.adopt_launched_task(kube_client, pod, pod_ids)
@@ -707,7 +734,7 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
return tis_to_flush
def adopt_launched_task(
- self, kube_client: client.CoreV1Api, pod: k8s.V1Pod, pod_ids: Dict[TaskInstanceKey, k8s.V1Pod]
+ self, kube_client: client.CoreV1Api, pod: k8s.V1Pod, pod_ids: dict[TaskInstanceKey, k8s.V1Pod]
) -> None:
"""
Patch existing pod so that the current KubernetesJobWatcher can monitor it via label selectors
@@ -716,83 +743,88 @@ def adopt_launched_task(
:param pod: V1Pod spec that we will patch with new label
:param pod_ids: pod_ids we expect to patch.
"""
- if not self.scheduler_job_id:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.scheduler_job_id
+
self.log.info("attempting to adopt pod %s", pod.metadata.name)
- pod.metadata.labels['airflow-worker'] = pod_generator.make_safe_label_value(self.scheduler_job_id)
pod_id = annotations_to_key(pod.metadata.annotations)
if pod_id not in pod_ids:
self.log.error("attempting to adopt taskinstance which was not specified by database: %s", pod_id)
return
+ new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id)
try:
kube_client.patch_namespaced_pod(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
- body=PodGenerator.serialize_pod(pod),
+ body={"metadata": {"labels": {"airflow-worker": new_worker_id_label}}},
)
- pod_ids.pop(pod_id)
- self.running.add(pod_id)
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
+ return
+
+ del pod_ids[pod_id]
+ self.running.add(pod_id)
def _adopt_completed_pods(self, kube_client: client.CoreV1Api) -> None:
"""
-
- Patch completed pod so that the KubernetesJobWatcher can delete it.
+ Patch completed pods so that the KubernetesJobWatcher can delete them.
:param kube_client: kubernetes client for speaking to kube API
"""
- if not self.scheduler_job_id:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.scheduler_job_id
+
+ new_worker_id_label = pod_generator.make_safe_label_value(self.scheduler_job_id)
kwargs = {
- 'field_selector': "status.phase=Succeeded",
- 'label_selector': 'kubernetes_executor=True',
+ "field_selector": "status.phase=Succeeded",
+ "label_selector": f"kubernetes_executor=True,airflow-worker!={new_worker_id_label}",
}
pod_list = kube_client.list_namespaced_pod(namespace=self.kube_config.kube_namespace, **kwargs)
for pod in pod_list.items:
self.log.info("Attempting to adopt pod %s", pod.metadata.name)
- pod.metadata.labels['airflow-worker'] = pod_generator.make_safe_label_value(self.scheduler_job_id)
try:
kube_client.patch_namespaced_pod(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
- body=PodGenerator.serialize_pod(pod),
+ body={"metadata": {"labels": {"airflow-worker": new_worker_id_label}}},
)
except ApiException as e:
self.log.info("Failed to adopt pod %s. Reason: %s", pod.metadata.name, e)
def _flush_task_queue(self) -> None:
- if not self.task_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
- self.log.debug('Executor shutting down, task_queue approximate size=%d', self.task_queue.qsize())
+ if TYPE_CHECKING:
+ assert self.task_queue
+
+ self.log.debug("Executor shutting down, task_queue approximate size=%d", self.task_queue.qsize())
while True:
try:
task = self.task_queue.get_nowait()
# This is a new task to run thus ok to ignore.
- self.log.warning('Executor shutting down, will NOT run task=%s', task)
+ self.log.warning("Executor shutting down, will NOT run task=%s", task)
self.task_queue.task_done()
except Empty:
break
def _flush_result_queue(self) -> None:
- if not self.result_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
- self.log.debug('Executor shutting down, result_queue approximate size=%d', self.result_queue.qsize())
+ if TYPE_CHECKING:
+ assert self.result_queue
+
+ self.log.debug("Executor shutting down, result_queue approximate size=%d", self.result_queue.qsize())
while True:
try:
results = self.result_queue.get_nowait()
- self.log.warning('Executor shutting down, flushing results=%s', results)
+ self.log.warning("Executor shutting down, flushing results=%s", results)
try:
key, state, pod_id, namespace, resource_version = results
self.log.info(
- 'Changing state of %s to %s : resource_version=%d', results, state, resource_version
+ "Changing state of %s to %s : resource_version=%d", results, state, resource_version
)
try:
self._change_state(key, state, pod_id, namespace)
except Exception as e:
self.log.exception(
- 'Ignoring exception: %s when attempting to change state of %s to %s.',
+ "Ignoring exception: %s when attempting to change state of %s to %s.",
e,
results,
state,
@@ -804,20 +836,22 @@ def _flush_result_queue(self) -> None:
def end(self) -> None:
"""Called when the executor shuts down"""
- if not self.task_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.result_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.kube_scheduler:
- raise AirflowException(NOT_STARTED_MESSAGE)
- self.log.info('Shutting down Kubernetes executor')
- self.log.debug('Flushing task_queue...')
- self._flush_task_queue()
- self.log.debug('Flushing result_queue...')
- self._flush_result_queue()
- # Both queues should be empty...
- self.task_queue.join()
- self.result_queue.join()
+ if TYPE_CHECKING:
+ assert self.task_queue
+ assert self.result_queue
+ assert self.kube_scheduler
+
+ self.log.info("Shutting down Kubernetes executor")
+ try:
+ self.log.debug("Flushing task_queue...")
+ self._flush_task_queue()
+ self.log.debug("Flushing result_queue...")
+ self._flush_result_queue()
+ # Both queues should be empty...
+ self.task_queue.join()
+ self.result_queue.join()
+ except ConnectionResetError:
+ self.log.exception("Connection Reset error while flushing task_queue and result_queue.")
if self.kube_scheduler:
self.kube_scheduler.terminate()
self._manager.shutdown()
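Note on the adoption hunks above: they replace the full `PodGenerator.serialize_pod(pod)` patch body with a body that carries only the `airflow-worker` label. The following is a minimal sketch of why a label-only body is enough, assuming the API server applies merge-style patch semantics (keys present in the patch overwrite, everything else is retained). `merge_patch` and the example pod are hypothetical helpers for illustration, not part of Airflow or the Kubernetes client, and deletion via `null` values is deliberately left out.

```python
from copy import deepcopy


def merge_patch(target: dict, patch: dict) -> dict:
    """Apply a merge-style patch: nested dicts merge, other values overwrite."""
    result = deepcopy(target)
    for key, value in patch.items():
        if isinstance(value, dict) and isinstance(result.get(key), dict):
            result[key] = merge_patch(result[key], value)
        else:
            result[key] = deepcopy(value)
    return result


# Hypothetical pod, reduced to the fields that matter here.
pod = {
    "metadata": {
        "name": "example-pod",
        "labels": {"airflow-worker": "1", "kubernetes_executor": "True"},
    },
    "spec": {"containers": [{"name": "base", "image": "example-image"}]},
}

# Same shape as the body used in the diff; "2" stands in for new_worker_id_label.
patch_body = {"metadata": {"labels": {"airflow-worker": "2"}}}

patched = merge_patch(pod, patch_body)
```

Only the one label changes; the other label and the whole `spec` are carried over untouched, so the request no longer depends on the locally held pod object being serializable or up to date.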
diff --git a/airflow/executors/local_executor.py b/airflow/executors/local_executor.py
index 431add69bcfbb..c2c82d863944a 100644
--- a/airflow/executors/local_executor.py
+++ b/airflow/executors/local_executor.py
@@ -22,6 +22,8 @@
For more information on how the LocalExecutor works, take a look at the guide:
:ref:`executor:LocalExecutor`
"""
+from __future__ import annotations
+
import logging
import os
import subprocess
@@ -29,13 +31,13 @@
from multiprocessing import Manager, Process
from multiprocessing.managers import SyncManager
from queue import Empty, Queue
-from typing import Any, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Optional, Tuple
from setproctitle import getproctitle, setproctitle
from airflow import settings
from airflow.exceptions import AirflowException
-from airflow.executors.base_executor import NOT_STARTED_MESSAGE, PARALLELISM, BaseExecutor, CommandType
+from airflow.executors.base_executor import PARALLELISM, BaseExecutor, CommandType
from airflow.models.taskinstance import TaskInstanceKey, TaskInstanceStateType
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
@@ -54,10 +56,10 @@ class LocalWorkerBase(Process, LoggingMixin):
:param result_queue: the queue to store result state
"""
- def __init__(self, result_queue: 'Queue[TaskInstanceStateType]'):
+ def __init__(self, result_queue: Queue[TaskInstanceStateType]):
super().__init__(target=self.do_work)
self.daemon: bool = True
- self.result_queue: 'Queue[TaskInstanceStateType]' = result_queue
+ self.result_queue: Queue[TaskInstanceStateType] = result_queue
def run(self):
# We know we've just started a new process, so lets disconnect from the metadata db now
@@ -148,7 +150,7 @@ class LocalWorker(LocalWorkerBase):
"""
def __init__(
- self, result_queue: 'Queue[TaskInstanceStateType]', key: TaskInstanceKey, command: CommandType
+ self, result_queue: Queue[TaskInstanceStateType], key: TaskInstanceKey, command: CommandType
):
super().__init__(result_queue)
self.key: TaskInstanceKey = key
@@ -168,7 +170,7 @@ class QueuedLocalWorker(LocalWorkerBase):
:param result_queue: queue where worker puts results after finishing tasks
"""
- def __init__(self, task_queue: 'Queue[ExecutorWorkType]', result_queue: 'Queue[TaskInstanceStateType]'):
+ def __init__(self, task_queue: Queue[ExecutorWorkType], result_queue: Queue[TaskInstanceStateType]):
super().__init__(result_queue=result_queue)
self.task_queue = task_queue
@@ -205,14 +207,12 @@ def __init__(self, parallelism: int = PARALLELISM):
super().__init__(parallelism=parallelism)
if self.parallelism < 0:
raise AirflowException("parallelism must be bigger than or equal to 0")
- self.manager: Optional[SyncManager] = None
- self.result_queue: Optional['Queue[TaskInstanceStateType]'] = None
- self.workers: List[QueuedLocalWorker] = []
+ self.manager: SyncManager | None = None
+ self.result_queue: Queue[TaskInstanceStateType] | None = None
+ self.workers: list[QueuedLocalWorker] = []
self.workers_used: int = 0
self.workers_active: int = 0
- self.impl: Optional[
- Union['LocalExecutor.UnlimitedParallelism', 'LocalExecutor.LimitedParallelism']
- ] = None
+ self.impl: LocalExecutor.UnlimitedParallelism | LocalExecutor.LimitedParallelism | None = None
class UnlimitedParallelism:
"""
@@ -222,8 +222,8 @@ class UnlimitedParallelism:
:param executor: the executor instance to implement.
"""
- def __init__(self, executor: 'LocalExecutor'):
- self.executor: 'LocalExecutor' = executor
+ def __init__(self, executor: LocalExecutor):
+ self.executor: LocalExecutor = executor
def start(self) -> None:
"""Starts the executor."""
@@ -234,8 +234,8 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None:
"""
Executes task asynchronously.
@@ -245,8 +245,9 @@ def execute_async(
:param queue: Name of the queue
:param executor_config: configuration for the executor
"""
- if not self.executor.result_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.executor.result_queue
+
local_worker = LocalWorker(self.executor.result_queue, key=key, command=command)
self.executor.workers_used += 1
self.executor.workers_active += 1
@@ -278,17 +279,17 @@ class LimitedParallelism:
:param executor: the executor instance to implement.
"""
- def __init__(self, executor: 'LocalExecutor'):
- self.executor: 'LocalExecutor' = executor
- self.queue: Optional['Queue[ExecutorWorkType]'] = None
+ def __init__(self, executor: LocalExecutor):
+ self.executor: LocalExecutor = executor
+ self.queue: Queue[ExecutorWorkType] | None = None
def start(self) -> None:
"""Starts limited parallelism implementation."""
- if not self.executor.manager:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.executor.manager
+ assert self.executor.result_queue
+
self.queue = self.executor.manager.Queue()
- if not self.executor.result_queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
self.executor.workers = [
QueuedLocalWorker(self.queue, self.executor.result_queue)
for _ in range(self.executor.parallelism)
@@ -303,8 +304,8 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None:
"""
Executes task asynchronously.
@@ -314,8 +315,9 @@ def execute_async(
:param queue: name of the queue
:param executor_config: configuration for the executor
"""
- if not self.queue:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.queue
+
self.queue.put((key, command))
def sync(self):
@@ -361,21 +363,22 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None:
"""Execute asynchronously."""
- if not self.impl:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.impl
- self.validate_command(command)
+ self.validate_airflow_tasks_run_command(command)
self.impl.execute_async(key=key, command=command, queue=queue, executor_config=executor_config)
def sync(self) -> None:
"""Sync will get called periodically by the heartbeat method."""
- if not self.impl:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.impl
+
self.impl.sync()
def end(self) -> None:
@@ -383,10 +386,10 @@ def end(self) -> None:
Ends the executor.
:return:
"""
- if not self.impl:
- raise AirflowException(NOT_STARTED_MESSAGE)
- if not self.manager:
- raise AirflowException(NOT_STARTED_MESSAGE)
+ if TYPE_CHECKING:
+ assert self.impl
+ assert self.manager
+
self.log.info(
"Shutting down LocalExecutor"
"; waiting for running tasks to finish. Signal again if you don't want to wait."
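Note on the recurring change in this file: runtime `raise AirflowException(NOT_STARTED_MESSAGE)` guards are swapped for `if TYPE_CHECKING: assert ...`. The following is a minimal sketch of what that means at runtime; `Worker` is a hypothetical class, not Airflow code. Because `typing.TYPE_CHECKING` is `False` when the program runs, the assertion is never executed. It only lets a static type checker narrow `Queue | None` to `Queue`.

```python
from __future__ import annotations

from queue import Queue
from typing import TYPE_CHECKING


class Worker:
    """Hypothetical stand-in for an executor with a lazily created queue."""

    def __init__(self) -> None:
        self.queue: Queue[str] | None = None

    def start(self) -> None:
        self.queue = Queue()

    def submit(self, item: str) -> None:
        if TYPE_CHECKING:
            # Seen only by static type checkers; skipped at runtime.
            assert self.queue
        self.queue.put(item)


worker = Worker()
try:
    # No runtime guard any more: calling before start() hits None directly.
    worker.submit("too early")
    error_name = None
except AttributeError as exc:
    error_name = type(exc).__name__

worker.start()
worker.submit("ok")
```

The trade-off is that calling a method before `start()` now surfaces as a plain `AttributeError` on `None` rather than a dedicated "not started" exception.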
diff --git a/airflow/executors/local_kubernetes_executor.py b/airflow/executors/local_kubernetes_executor.py
index 9944cfe1ef1fa..b6151d5d1c1d5 100644
--- a/airflow/executors/local_kubernetes_executor.py
+++ b/airflow/executors/local_kubernetes_executor.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, List, Optional, Set, Union
+from __future__ import annotations
+
+from typing import Sequence
from airflow.callbacks.base_callback_sink import BaseCallbackSink
from airflow.callbacks.callback_requests import CallbackRequest
@@ -37,19 +39,19 @@ class LocalKubernetesExecutor(LoggingMixin):
"""
supports_ad_hoc_ti_run: bool = True
- callback_sink: Optional[BaseCallbackSink] = None
+ callback_sink: BaseCallbackSink | None = None
- KUBERNETES_QUEUE = conf.get('local_kubernetes_executor', 'kubernetes_queue')
+ KUBERNETES_QUEUE = conf.get("local_kubernetes_executor", "kubernetes_queue")
def __init__(self, local_executor: LocalExecutor, kubernetes_executor: KubernetesExecutor):
super().__init__()
- self._job_id: Optional[str] = None
+ self._job_id: str | None = None
self.local_executor = local_executor
self.kubernetes_executor = kubernetes_executor
self.kubernetes_executor.kubernetes_queue = self.KUBERNETES_QUEUE
@property
- def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]:
+ def queued_tasks(self) -> dict[TaskInstanceKey, QueuedTaskInstanceType]:
"""Return queued tasks from local and kubernetes executor"""
queued_tasks = self.local_executor.queued_tasks.copy()
queued_tasks.update(self.kubernetes_executor.queued_tasks)
@@ -57,12 +59,12 @@ def queued_tasks(self) -> Dict[TaskInstanceKey, QueuedTaskInstanceType]:
return queued_tasks
@property
- def running(self) -> Set[TaskInstanceKey]:
+ def running(self) -> set[TaskInstanceKey]:
"""Return running tasks from local and kubernetes executor"""
return self.local_executor.running.union(self.kubernetes_executor.running)
@property
- def job_id(self) -> Optional[str]:
+ def job_id(self) -> str | None:
"""
This is a class attribute in BaseExecutor but since this is not really an executor, but a wrapper
of executors we implement as property so we can have custom setter.
@@ -70,7 +72,7 @@ def job_id(self) -> Optional[str]:
return self._job_id
@job_id.setter
- def job_id(self, value: Optional[str]) -> None:
+ def job_id(self, value: str | None) -> None:
"""job_id is manipulated by SchedulerJob. We must propagate the job_id to wrapped executors."""
self._job_id = value
self.kubernetes_executor.job_id = value
@@ -92,7 +94,7 @@ def queue_command(
task_instance: TaskInstance,
command: CommandType,
priority: int = 1,
- queue: Optional[str] = None,
+ queue: str | None = None,
) -> None:
"""Queues command via local or kubernetes executor"""
executor = self._router(task_instance)
@@ -103,13 +105,13 @@ def queue_task_instance(
self,
task_instance: TaskInstance,
mark_success: bool = False,
- pickle_id: Optional[str] = None,
+ pickle_id: str | None = None,
ignore_all_deps: bool = False,
ignore_depends_on_past: bool = False,
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
) -> None:
"""Queues task instance via local or kubernetes executor"""
executor = self._router(SimpleTaskInstance.from_ti(task_instance))
@@ -143,8 +145,8 @@ def heartbeat(self) -> None:
self.kubernetes_executor.heartbeat()
def get_event_buffer(
- self, dag_ids: Optional[List[str]] = None
- ) -> Dict[TaskInstanceKey, EventBufferValueType]:
+ self, dag_ids: list[str] | None = None
+ ) -> dict[TaskInstanceKey, EventBufferValueType]:
"""
Returns and flushes the event buffer from local and kubernetes executor
@@ -156,7 +158,7 @@ def get_event_buffer(
return {**cleared_events_from_local, **cleared_events_from_kubernetes}
- def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
+ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[TaskInstance]:
"""
Try to adopt running task instances that have been abandoned by a SchedulerJob dying.
@@ -164,19 +166,13 @@ def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance
re-scheduling)
:return: any TaskInstances that were unable to be adopted
- :rtype: list[airflow.models.TaskInstance]
"""
- local_tis = []
- kubernetes_tis = []
- abandoned_tis = []
- for ti in tis:
- if ti.queue == self.KUBERNETES_QUEUE:
- kubernetes_tis.append(ti)
- else:
- local_tis.append(ti)
- abandoned_tis.extend(self.local_executor.try_adopt_task_instances(local_tis))
- abandoned_tis.extend(self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis))
- return abandoned_tis
+ local_tis = [ti for ti in tis if ti.queue != self.KUBERNETES_QUEUE]
+ kubernetes_tis = [ti for ti in tis if ti.queue == self.KUBERNETES_QUEUE]
+ return [
+ *self.local_executor.try_adopt_task_instances(local_tis),
+ *self.kubernetes_executor.try_adopt_task_instances(kubernetes_tis),
+ ]
def end(self) -> None:
"""End local and kubernetes executor"""
@@ -188,13 +184,12 @@ def terminate(self) -> None:
self.local_executor.terminate()
self.kubernetes_executor.terminate()
- def _router(self, simple_task_instance: SimpleTaskInstance) -> Union[LocalExecutor, KubernetesExecutor]:
+ def _router(self, simple_task_instance: SimpleTaskInstance) -> LocalExecutor | KubernetesExecutor:
"""
Return either local_executor or kubernetes_executor
:param simple_task_instance: SimpleTaskInstance
:return: local_executor or kubernetes_executor
- :rtype: Union[LocalExecutor, KubernetesExecutor]
"""
if simple_task_instance.queue == self.KUBERNETES_QUEUE:
return self.kubernetes_executor
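Reviewer note: the rewritten `try_adopt_task_instances` above replaces an append loop with two list comprehensions keyed on the task's queue. A minimal, standalone sketch of that partition follows. `FakeTask`, `split_by_queue`, and the `"kubernetes"` queue name are hypothetical stand-ins for illustration only, not Airflow classes or settings.

```python
from __future__ import annotations

from dataclasses import dataclass

# Assumed queue name, for illustration only; the real value comes from configuration.
KUBERNETES_QUEUE = "kubernetes"


@dataclass
class FakeTask:
    """Hypothetical stand-in for a task instance; only ``queue`` matters here."""

    name: str
    queue: str


def split_by_queue(tasks: list[FakeTask]) -> tuple[list[FakeTask], list[FakeTask]]:
    """Partition tasks into (local, kubernetes) using the queue name alone."""
    local = [t for t in tasks if t.queue != KUBERNETES_QUEUE]
    kubernetes = [t for t in tasks if t.queue == KUBERNETES_QUEUE]
    return local, kubernetes
```

The two comprehensions traverse the input twice, which is why the patch widens the parameter to `Sequence` rather than accepting a one-shot iterator.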
diff --git a/airflow/executors/sequential_executor.py b/airflow/executors/sequential_executor.py
index 456e3e9893e8b..c7c2f00417a21 100644
--- a/airflow/executors/sequential_executor.py
+++ b/airflow/executors/sequential_executor.py
@@ -22,8 +22,10 @@
For more information on how the SequentialExecutor works, take a look at the guide:
:ref:`executor:SequentialExecutor`
"""
+from __future__ import annotations
+
import subprocess
-from typing import Any, Optional
+from typing import Any
from airflow.executors.base_executor import BaseExecutor, CommandType
from airflow.models.taskinstance import TaskInstanceKey
@@ -48,10 +50,10 @@ def execute_async(
self,
key: TaskInstanceKey,
command: CommandType,
- queue: Optional[str] = None,
- executor_config: Optional[Any] = None,
+ queue: str | None = None,
+ executor_config: Any | None = None,
) -> None:
- self.validate_command(command)
+ self.validate_airflow_tasks_run_command(command)
self.commands_to_run.append((key, command))
def sync(self) -> None:
diff --git a/airflow/hooks/S3_hook.py b/airflow/hooks/S3_hook.py
deleted file mode 100644
index b59311a1ba9e4..0000000000000
--- a/airflow/hooks/S3_hook.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook, provide_bucket_name # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/__init__.py b/airflow/hooks/__init__.py
index 5c933400c2f8a..db2b7f0a298ee 100644
--- a/airflow/hooks/__init__.py
+++ b/airflow/hooks/__init__.py
@@ -15,4 +15,80 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# fmt:, off
"""Hooks."""
+from __future__ import annotations
+
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+__deprecated_classes = {
+ "S3_hook": {
+ "S3Hook": "airflow.providers.amazon.aws.hooks.s3.S3Hook",
+ "provide_bucket_name": "airflow.providers.amazon.aws.hooks.s3.provide_bucket_name",
+ },
+ "base_hook": {
+ "BaseHook": "airflow.hooks.base.BaseHook",
+ },
+ "dbapi_hook": {
+ "DbApiHook": "airflow.providers.common.sql.hooks.sql.DbApiHook",
+ },
+ "docker_hook": {
+ "DockerHook": "airflow.providers.docker.hooks.docker.DockerHook",
+ },
+ "druid_hook": {
+ "DruidDbApiHook": "airflow.providers.apache.druid.hooks.druid.DruidDbApiHook",
+ "DruidHook": "airflow.providers.apache.druid.hooks.druid.DruidHook",
+ },
+ "hdfs_hook": {
+ "HDFSHook": "airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook",
+ "HDFSHookException": "airflow.providers.apache.hdfs.hooks.hdfs.HDFSHookException",
+ },
+ "hive_hooks": {
+ "HIVE_QUEUE_PRIORITIES": "airflow.providers.apache.hive.hooks.hive.HIVE_QUEUE_PRIORITIES",
+ "HiveCliHook": "airflow.providers.apache.hive.hooks.hive.HiveCliHook",
+ "HiveMetastoreHook": "airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook",
+ "HiveServer2Hook": "airflow.providers.apache.hive.hooks.hive.HiveServer2Hook",
+ },
+ "http_hook": {
+ "HttpHook": "airflow.providers.http.hooks.http.HttpHook",
+ },
+ "jdbc_hook": {
+ "JdbcHook": "airflow.providers.jdbc.hooks.jdbc.JdbcHook",
+ "jaydebeapi": "airflow.providers.jdbc.hooks.jdbc.jaydebeapi",
+ },
+ "mssql_hook": {
+ "MsSqlHook": "airflow.providers.microsoft.mssql.hooks.mssql.MsSqlHook",
+ },
+ "mysql_hook": {
+ "MySqlHook": "airflow.providers.mysql.hooks.mysql.MySqlHook",
+ },
+ "oracle_hook": {
+ "OracleHook": "airflow.providers.oracle.hooks.oracle.OracleHook",
+ },
+ "pig_hook": {
+ "PigCliHook": "airflow.providers.apache.pig.hooks.pig.PigCliHook",
+ },
+ "postgres_hook": {
+ "PostgresHook": "airflow.providers.postgres.hooks.postgres.PostgresHook",
+ },
+ "presto_hook": {
+ "PrestoHook": "airflow.providers.presto.hooks.presto.PrestoHook",
+ },
+ "samba_hook": {
+ "SambaHook": "airflow.providers.samba.hooks.samba.SambaHook",
+ },
+ "slack_hook": {
+ "SlackHook": "airflow.providers.slack.hooks.slack.SlackHook",
+ },
+ "sqlite_hook": {
+ "SqliteHook": "airflow.providers.sqlite.hooks.sqlite.SqliteHook",
+ },
+ "webhdfs_hook": {
+ "WebHDFSHook": "airflow.providers.apache.hdfs.hooks.webhdfs.WebHDFSHook",
+ },
+ "zendesk_hook": {
+ "ZendeskHook": "airflow.providers.zendesk.hooks.zendesk.ZendeskHook",
+ },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
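Reviewer note: the mapping above replaces one shim file per deprecated module with a single table of old name to new dotted path. One way such a table could be served lazily is a module-level `__getattr__` (PEP 562). The sketch below is an assumption about the general technique, not the actual implementation of `add_deprecated_classes`. `make_deprecated_module`, `legacy_demo`, and `OrderedDictAlias` are hypothetical names.

```python
from __future__ import annotations

import importlib
import sys
import types
import warnings

# Hypothetical mapping in the same shape as the diff: old attribute -> new dotted path.
DEPRECATED = {"OrderedDictAlias": "collections.OrderedDict"}


def make_deprecated_module(name: str, mapping: dict[str, str]) -> types.ModuleType:
    """Register a virtual module whose attributes redirect to their new locations."""
    module = types.ModuleType(name)

    def __getattr__(attr: str):
        # Called only when normal attribute lookup on the module fails (PEP 562).
        if attr not in mapping:
            raise AttributeError(f"module {name!r} has no attribute {attr!r}")
        new_path = mapping[attr]
        warnings.warn(
            f"{name}.{attr} is deprecated. Please use {new_path}.",
            DeprecationWarning,
            stacklevel=2,
        )
        module_path, _, class_name = new_path.rpartition(".")
        return getattr(importlib.import_module(module_path), class_name)

    module.__getattr__ = __getattr__
    sys.modules[name] = module
    return module
```

With this approach the target module is imported only when the deprecated name is first accessed, so listing many redirects adds no import cost up front.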
diff --git a/airflow/hooks/base.py b/airflow/hooks/base.py
index aa506cf62de28..9298a686889f7 100644
--- a/airflow/hooks/base.py
+++ b/airflow/hooks/base.py
@@ -15,11 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Base class for all hooks"""
+"""Base class for all hooks."""
+from __future__ import annotations
+
import logging
import warnings
-from typing import TYPE_CHECKING, Any, Dict, List
+from typing import TYPE_CHECKING, Any
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.typing_compat import Protocol
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -31,7 +34,9 @@
class BaseHook(LoggingMixin):
"""
- Abstract base class for hooks, hooks are meant as an interface to
+ Abstract base class for hooks.
+
+ Hooks are meant as an interface to
interact with external systems. MySqlHook, HiveHook, PigHook return
object that can handle the connection and interaction to specific
instances of these systems, and expose consistent methods to interact
@@ -39,7 +44,7 @@ class BaseHook(LoggingMixin):
"""
@classmethod
- def get_connections(cls, conn_id: str) -> List["Connection"]:
+ def get_connections(cls, conn_id: str) -> list[Connection]:
"""
Get all connections as an iterable, given the connection id.
@@ -49,13 +54,13 @@ def get_connections(cls, conn_id: str) -> List["Connection"]:
warnings.warn(
"`BaseHook.get_connections` method will be deprecated in the future."
"Please use `BaseHook.get_connection` instead.",
- PendingDeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return [cls.get_connection(conn_id)]
@classmethod
- def get_connection(cls, conn_id: str) -> "Connection":
+ def get_connection(cls, conn_id: str) -> Connection:
"""
Get connection, given connection id.
@@ -69,15 +74,13 @@ def get_connection(cls, conn_id: str) -> "Connection":
return conn
@classmethod
- def get_hook(cls, conn_id: str) -> "BaseHook":
+ def get_hook(cls, conn_id: str) -> BaseHook:
"""
Returns default hook for this connection id.
:param conn_id: connection id
:return: default hook for this connection
"""
- # TODO: set method return type to BaseHook class when on 3.7+.
- # See https://stackoverflow.com/a/33533514/3066428
connection = cls.get_connection(conn_id)
return connection.get_hook()
@@ -86,11 +89,11 @@ def get_conn(self) -> Any:
raise NotImplementedError()
@classmethod
- def get_connection_form_widgets(cls) -> Dict[str, Any]:
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
...
@classmethod
- def get_ui_field_behaviour(cls) -> Dict[str, Any]:
+ def get_ui_field_behaviour(cls) -> dict[str, Any]:
...
@@ -139,7 +142,7 @@ def get_ui_field_behaviour(cls):
hook_name: str
@staticmethod
- def get_connection_form_widgets() -> Dict[str, Any]:
+ def get_connection_form_widgets() -> dict[str, Any]:
"""
Returns dictionary of widgets to be added for the hook to handle extra values.
@@ -155,8 +158,10 @@ def get_connection_form_widgets() -> Dict[str, Any]:
...
@staticmethod
- def get_ui_field_behaviour() -> Dict[str, Any]:
+ def get_ui_field_behaviour() -> dict[str, Any]:
"""
+ Attributes of the UI field.
+
Returns dictionary describing customizations to implement in javascript handling the
connection form. Should be compliant with airflow/customized_form_field_behaviours.schema.json'
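Reviewer note: the hunks above drop the quotes around `"Connection"` and `"BaseHook"` and switch to `list[...]` and `dict[...]`. That works because `from __future__ import annotations` (PEP 563) stores annotations as strings and does not evaluate them at definition time. A small sketch with a hypothetical `Node` class:

```python
from __future__ import annotations


class Node:
    """Hypothetical class that refers to itself in its own annotations."""

    def child(self) -> Node:
        # Without the __future__ import, ``Node`` would not be defined yet
        # when this signature is evaluated, so the name would need quoting.
        return Node()

    @classmethod
    def many(cls) -> list[Node]:
        # Built-in generics in annotations are also safe on older interpreters,
        # because the annotation is never evaluated at definition time.
        return [cls()]
```

Names imported only under `TYPE_CHECKING` can be used unquoted for the same reason, as long as nothing resolves the annotations at runtime.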
diff --git a/airflow/hooks/base_hook.py b/airflow/hooks/base_hook.py
deleted file mode 100644
index cf1594d18d284..0000000000000
--- a/airflow/hooks/base_hook.py
+++ /dev/null
@@ -1,24 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.hooks.base`."""
-
-import warnings
-
-from airflow.hooks.base import BaseHook # noqa
-
-warnings.warn("This module is deprecated. Please use `airflow.hooks.base`.", DeprecationWarning, stacklevel=2)
diff --git a/airflow/hooks/dbapi.py b/airflow/hooks/dbapi.py
index 0b9ce4377be23..cd4a39af8d2a5 100644
--- a/airflow/hooks/dbapi.py
+++ b/airflow/hooks/dbapi.py
@@ -15,363 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from contextlib import closing
-from datetime import datetime
-from typing import Any, Optional
+from __future__ import annotations
-from sqlalchemy import create_engine
+import warnings
-from airflow.exceptions import AirflowException
-from airflow.hooks.base import BaseHook
-from airflow.typing_compat import Protocol
+from airflow.exceptions import RemovedInAirflow3Warning
+from airflow.providers.common.sql.hooks.sql import ConnectorProtocol # noqa
+from airflow.providers.common.sql.hooks.sql import DbApiHook # noqa
-
-class ConnectorProtocol(Protocol):
- """A protocol where you can connect to a database."""
-
- def connect(self, host: str, port: int, username: str, schema: str) -> Any:
- """
- Connect to a database.
-
- :param host: The database host to connect to.
- :param port: The database port to connect to.
- :param username: The database username used for the authentication.
- :param schema: The database schema to connect to.
- :return: the authorized connection object.
- """
-
-
-#########################################################################################
-# #
-# Note! Be extra careful when changing this file. This hook is used as a base for #
-# a number of DBApi-related hooks and providers depend on the methods implemented #
-# here. Whatever you add here, has to backwards compatible unless #
-# `>=` is added to providers' requirements using the new feature #
-# #
-#########################################################################################
-class DbApiHook(BaseHook):
- """
- Abstract base class for sql hooks.
-
- :param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that
- if you change the schema parameter value in the constructor of the derived Hook, such change
- should be done before calling the ``DBApiHook.__init__()``.
- """
-
- # Override to provide the connection name.
- conn_name_attr = None # type: str
- # Override to have a default connection id for a particular dbHook
- default_conn_name = 'default_conn_id'
- # Override if this db supports autocommit.
- supports_autocommit = False
- # Override with the object that exposes the connect method
- connector = None # type: Optional[ConnectorProtocol]
- # Override with db-specific query to check connection
- _test_connection_sql = "select 1"
-
- def __init__(self, *args, schema: Optional[str] = None, **kwargs):
- super().__init__()
- if not self.conn_name_attr:
- raise AirflowException("conn_name_attr is not defined")
- elif len(args) == 1:
- setattr(self, self.conn_name_attr, args[0])
- elif self.conn_name_attr not in kwargs:
- setattr(self, self.conn_name_attr, self.default_conn_name)
- else:
- setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr])
- # We should not make schema available in deriving hooks for backwards compatibility
- # If a hook deriving from DBApiHook has a need to access schema, then it should retrieve it
- # from kwargs and store it on its own. We do not run "pop" here as we want to give the
- # Hook deriving from the DBApiHook to still have access to the field in it's constructor
- self.__schema = schema
-
- def get_conn(self):
- """Returns a connection object"""
- db = self.get_connection(getattr(self, self.conn_name_attr))
- return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)
-
- def get_uri(self) -> str:
- """
- Extract the URI from the connection.
-
- :return: the extracted uri.
- """
- conn = self.get_connection(getattr(self, self.conn_name_attr))
- conn.schema = self.__schema or conn.schema
- return conn.get_uri()
-
- def get_sqlalchemy_engine(self, engine_kwargs=None):
- """
- Get an sqlalchemy_engine object.
-
- :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
- :return: the created engine.
- """
- if engine_kwargs is None:
- engine_kwargs = {}
- return create_engine(self.get_uri(), **engine_kwargs)
-
- def get_pandas_df(self, sql, parameters=None, **kwargs):
- """
- Executes the sql and returns a pandas dataframe
-
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
- :param parameters: The parameters to render the SQL query with.
- :param kwargs: (optional) passed into pandas.io.sql.read_sql method
- """
- try:
- from pandas.io import sql as psql
- except ImportError:
- raise Exception("pandas library not installed, run: pip install 'apache-airflow[pandas]'.")
-
- with closing(self.get_conn()) as conn:
- return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
-
- def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs):
- """
- Executes the sql and returns a generator
-
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
- :param parameters: The parameters to render the SQL query with
- :param chunksize: number of rows to include in each chunk
- :param kwargs: (optional) passed into pandas.io.sql.read_sql method
- """
- try:
- from pandas.io import sql as psql
- except ImportError:
- raise Exception("pandas library not installed, run: pip install 'apache-airflow[pandas]'.")
-
- with closing(self.get_conn()) as conn:
- yield from psql.read_sql(sql, con=conn, params=parameters, chunksize=chunksize, **kwargs)
-
- def get_records(self, sql, parameters=None):
- """
- Executes the sql and returns a set of records.
-
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
- :param parameters: The parameters to render the SQL query with.
- """
- with closing(self.get_conn()) as conn:
- with closing(conn.cursor()) as cur:
- if parameters is not None:
- cur.execute(sql, parameters)
- else:
- cur.execute(sql)
- return cur.fetchall()
-
- def get_first(self, sql, parameters=None):
- """
- Executes the sql and returns the first resulting row.
-
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
- :param parameters: The parameters to render the SQL query with.
- """
- with closing(self.get_conn()) as conn:
- with closing(conn.cursor()) as cur:
- if parameters is not None:
- cur.execute(sql, parameters)
- else:
- cur.execute(sql)
- return cur.fetchone()
-
- def run(self, sql, autocommit=False, parameters=None, handler=None):
- """
- Runs a command or a list of commands. Pass a list of sql
- statements to the sql parameter to get them to execute
- sequentially
-
- :param sql: the sql statement to be executed (str) or a list of
- sql statements to execute
- :param autocommit: What to set the connection's autocommit setting to
- before executing the query.
- :param parameters: The parameters to render the SQL query with.
- :param handler: The result handler which is called with the result of each statement.
- :return: query results if handler was provided.
- """
- scalar = isinstance(sql, str)
- if scalar:
- sql = [sql]
-
- if sql:
- self.log.debug("Executing %d statements", len(sql))
- else:
- raise ValueError("List of SQL statements is empty")
-
- with closing(self.get_conn()) as conn:
- if self.supports_autocommit:
- self.set_autocommit(conn, autocommit)
-
- with closing(conn.cursor()) as cur:
- results = []
- for sql_statement in sql:
- self._run_command(cur, sql_statement, parameters)
- if handler is not None:
- result = handler(cur)
- results.append(result)
-
- # If autocommit was set to False for db that supports autocommit,
- # or if db does not supports autocommit, we do a manual commit.
- if not self.get_autocommit(conn):
- conn.commit()
-
- if handler is None:
- return None
-
- if scalar:
- return results[0]
-
- return results
-
- def _run_command(self, cur, sql_statement, parameters):
- """Runs a statement using an already open cursor."""
- self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
- if parameters:
- cur.execute(sql_statement, parameters)
- else:
- cur.execute(sql_statement)
-
- # According to PEP 249, this is -1 when query result is not applicable.
- if cur.rowcount >= 0:
- self.log.info("Rows affected: %s", cur.rowcount)
-
- def set_autocommit(self, conn, autocommit):
- """Sets the autocommit flag on the connection"""
- if not self.supports_autocommit and autocommit:
- self.log.warning(
- "%s connection doesn't support autocommit but autocommit activated.",
- getattr(self, self.conn_name_attr),
- )
- conn.autocommit = autocommit
-
- def get_autocommit(self, conn):
- """
- Get autocommit setting for the provided connection.
- Return True if conn.autocommit is set to True.
- Return False if conn.autocommit is not set or set to False or conn
- does not support autocommit.
-
- :param conn: Connection to get autocommit setting from.
- :return: connection autocommit setting.
- :rtype: bool
- """
- return getattr(conn, 'autocommit', False) and self.supports_autocommit
-
- def get_cursor(self):
- """Returns a cursor"""
- return self.get_conn().cursor()
-
- @staticmethod
- def _generate_insert_sql(table, values, target_fields, replace, **kwargs):
- """
- Static helper method that generates the INSERT SQL statement.
- The REPLACE variant is specific to MySQL syntax.
-
- :param table: Name of the target table
- :param values: The row to insert into the table
- :param target_fields: The names of the columns to fill in the table
- :param replace: Whether to replace instead of insert
- :return: The generated INSERT or REPLACE SQL statement
- :rtype: str
- """
- placeholders = [
- "%s",
- ] * len(values)
-
- if target_fields:
- target_fields = ", ".join(target_fields)
- target_fields = f"({target_fields})"
- else:
- target_fields = ''
-
- if not replace:
- sql = "INSERT INTO "
- else:
- sql = "REPLACE INTO "
- sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})"
- return sql
-
- def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs):
- """
- A generic way to insert a set of tuples into a table,
- a new transaction is created every commit_every rows
-
- :param table: Name of the target table
- :param rows: The rows to insert into the table
- :param target_fields: The names of the columns to fill in the table
- :param commit_every: The maximum number of rows to insert in one
- transaction. Set to 0 to insert all rows in one transaction.
- :param replace: Whether to replace instead of insert
- """
- i = 0
- with closing(self.get_conn()) as conn:
- if self.supports_autocommit:
- self.set_autocommit(conn, False)
-
- conn.commit()
-
- with closing(conn.cursor()) as cur:
- for i, row in enumerate(rows, 1):
- lst = []
- for cell in row:
- lst.append(self._serialize_cell(cell, conn))
- values = tuple(lst)
- sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
- self.log.debug("Generated sql: %s", sql)
- cur.execute(sql, values)
- if commit_every and i % commit_every == 0:
- conn.commit()
- self.log.info("Loaded %s rows into %s so far", i, table)
-
- conn.commit()
- self.log.info("Done loading. Loaded a total of %s rows", i)
-
- @staticmethod
- def _serialize_cell(cell, conn=None):
- """
- Returns the SQL literal of the cell as a string.
-
- :param cell: The cell to insert into the table
- :param conn: The database connection
- :return: The serialized cell
- :rtype: str
- """
- if cell is None:
- return None
- if isinstance(cell, datetime):
- return cell.isoformat()
- return str(cell)
-
- def bulk_dump(self, table, tmp_file):
- """
- Dumps a database table into a tab-delimited file
-
- :param table: The name of the source table
- :param tmp_file: The path of the target file
- """
- raise NotImplementedError()
-
- def bulk_load(self, table, tmp_file):
- """
- Loads a tab-delimited file into a database table
-
- :param table: The name of the target table
- :param tmp_file: The path of the file to load into the table
- """
- raise NotImplementedError()
-
- def test_connection(self):
- """Tests the connection using db-specific query"""
- status, message = False, ''
- try:
- if self.get_first(self._test_connection_sql):
- status = True
- message = 'Connection successfully tested'
- except Exception as e:
- status = False
- message = str(e)
-
- return status, message
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.common.sql.hooks.sql`.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+)
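Reviewer note: after this change the old module only re-exports names and emits a project-specific warning category. The sketch below shows why a dedicated category can be useful, assuming it subclasses `DeprecationWarning`. `RemovedInNextMajorWarning`, `old_api`, and `new_api` are hypothetical names, not part of Airflow.

```python
import warnings


class RemovedInNextMajorWarning(DeprecationWarning):
    """Hypothetical project-specific deprecation category."""


def new_api() -> int:
    return 42


def old_api() -> int:
    """Deprecated alias kept for backwards compatibility."""
    warnings.warn(
        "old_api is deprecated. Please use new_api.",
        RemovedInNextMajorWarning,
        stacklevel=2,  # attribute the warning to the caller, not to this shim
    )
    return new_api()
```

Because the category is a subclass, existing `DeprecationWarning` filters still match it, while a test suite can escalate only this project's deprecations with `warnings.filterwarnings("error", category=RemovedInNextMajorWarning)`.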
diff --git a/airflow/hooks/dbapi_hook.py b/airflow/hooks/dbapi_hook.py
deleted file mode 100644
index 4a441b0f50d59..0000000000000
--- a/airflow/hooks/dbapi_hook.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.hooks.dbapi`."""
-
-import warnings
-
-from airflow.hooks.dbapi import DbApiHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.hooks.dbapi`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py
deleted file mode 100644
index aaedd7e637d93..0000000000000
--- a/airflow/hooks/docker_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.docker.hooks.docker`."""
-
-import warnings
-
-from airflow.providers.docker.hooks.docker import DockerHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.docker.hooks.docker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/druid_hook.py b/airflow/hooks/druid_hook.py
deleted file mode 100644
index 0a43debbabddc..0000000000000
--- a/airflow/hooks/druid_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.hooks.druid`."""
-
-import warnings
-
-from airflow.providers.apache.druid.hooks.druid import DruidDbApiHook, DruidHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.druid.hooks.druid`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/filesystem.py b/airflow/hooks/filesystem.py
index d9ad6ded1cc12..39517e8cdc1ad 100644
--- a/airflow/hooks/filesystem.py
+++ b/airflow/hooks/filesystem.py
@@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
from airflow.hooks.base import BaseHook
@@ -33,13 +33,13 @@ class FSHook(BaseHook):
Extra: {"path": "/tmp"}
"""
- def __init__(self, conn_id='fs_default'):
+ def __init__(self, conn_id: str = "fs_default"):
super().__init__()
conn = self.get_connection(conn_id)
- self.basepath = conn.extra_dejson.get('path', '')
+ self.basepath = conn.extra_dejson.get("path", "")
self.conn = conn
- def get_conn(self):
+ def get_conn(self) -> None:
pass
def get_path(self) -> str:
diff --git a/airflow/hooks/hdfs_hook.py b/airflow/hooks/hdfs_hook.py
deleted file mode 100644
index fd13e7337e262..0000000000000
--- a/airflow/hooks/hdfs_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hdfs.hooks.hdfs`."""
-
-import warnings
-
-from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook, HDFSHookException # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.hdfs`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/hive_hooks.py b/airflow/hooks/hive_hooks.py
deleted file mode 100644
index 74d7863c8d947..0000000000000
--- a/airflow/hooks/hive_hooks.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.hooks.hive`."""
-
-import warnings
-
-from airflow.providers.apache.hive.hooks.hive import ( # noqa
- HIVE_QUEUE_PRIORITIES,
- HiveCliHook,
- HiveMetastoreHook,
- HiveServer2Hook,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.hooks.hive`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py
deleted file mode 100644
index 5b8c1fdf9b776..0000000000000
--- a/airflow/hooks/http_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.http.hooks.http`."""
-
-import warnings
-
-from airflow.providers.http.hooks.http import HttpHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.http.hooks.http`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/jdbc_hook.py b/airflow/hooks/jdbc_hook.py
deleted file mode 100644
index a032ab0e2598b..0000000000000
--- a/airflow/hooks/jdbc_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.jdbc.hooks.jdbc`."""
-
-import warnings
-
-from airflow.providers.jdbc.hooks.jdbc import JdbcHook, jaydebeapi # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.jdbc.hooks.jdbc`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/mssql_hook.py b/airflow/hooks/mssql_hook.py
deleted file mode 100644
index 64943eeea7905..0000000000000
--- a/airflow/hooks/mssql_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.mssql.hooks.mssql`."""
-
-import warnings
-
-from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.mssql.hooks.mssql`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/mysql_hook.py b/airflow/hooks/mysql_hook.py
deleted file mode 100644
index 437313680b09c..0000000000000
--- a/airflow/hooks/mysql_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.mysql.hooks.mysql`."""
-
-import warnings
-
-from airflow.providers.mysql.hooks.mysql import MySqlHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.mysql.hooks.mysql`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/oracle_hook.py b/airflow/hooks/oracle_hook.py
deleted file mode 100644
index 0dfe33a78ae2a..0000000000000
--- a/airflow/hooks/oracle_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.oracle.hooks.oracle`."""
-
-import warnings
-
-from airflow.providers.oracle.hooks.oracle import OracleHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.oracle.hooks.oracle`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/pig_hook.py b/airflow/hooks/pig_hook.py
deleted file mode 100644
index 3ead3df6c826e..0000000000000
--- a/airflow/hooks/pig_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.pig.hooks.pig`."""
-
-import warnings
-
-from airflow.providers.apache.pig.hooks.pig import PigCliHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.pig.hooks.pig`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/postgres_hook.py b/airflow/hooks/postgres_hook.py
deleted file mode 100644
index 16f79dc329593..0000000000000
--- a/airflow/hooks/postgres_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.postgres.hooks.postgres`."""
-
-import warnings
-
-from airflow.providers.postgres.hooks.postgres import PostgresHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.postgres.hooks.postgres`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/presto_hook.py b/airflow/hooks/presto_hook.py
deleted file mode 100644
index 0c33e1423d35d..0000000000000
--- a/airflow/hooks/presto_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.presto.hooks.presto`."""
-
-import warnings
-
-from airflow.providers.presto.hooks.presto import PrestoHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.presto.hooks.presto`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/samba_hook.py b/airflow/hooks/samba_hook.py
deleted file mode 100644
index b4c7cf83b05a6..0000000000000
--- a/airflow/hooks/samba_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.samba.hooks.samba`."""
-
-import warnings
-
-from airflow.providers.samba.hooks.samba import SambaHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.samba.hooks.samba`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/slack_hook.py b/airflow/hooks/slack_hook.py
deleted file mode 100644
index 43636b2c6eeef..0000000000000
--- a/airflow/hooks/slack_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.slack.hooks.slack`."""
-
-import warnings
-
-from airflow.providers.slack.hooks.slack import SlackHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.slack.hooks.slack`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/sqlite_hook.py b/airflow/hooks/sqlite_hook.py
deleted file mode 100644
index 773900400ccbc..0000000000000
--- a/airflow/hooks/sqlite_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.sqlite.hooks.sqlite`."""
-
-import warnings
-
-from airflow.providers.sqlite.hooks.sqlite import SqliteHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.sqlite.hooks.sqlite`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/subprocess.py b/airflow/hooks/subprocess.py
index fa8c706c6963d..84479a3a40809 100644
--- a/airflow/hooks/subprocess.py
+++ b/airflow/hooks/subprocess.py
@@ -14,32 +14,33 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import contextlib
import os
import signal
from collections import namedtuple
from subprocess import PIPE, STDOUT, Popen
from tempfile import TemporaryDirectory, gettempdir
-from typing import Dict, List, Optional
from airflow.hooks.base import BaseHook
-SubprocessResult = namedtuple('SubprocessResult', ['exit_code', 'output'])
+SubprocessResult = namedtuple("SubprocessResult", ["exit_code", "output"])
class SubprocessHook(BaseHook):
- """Hook for running processes with the ``subprocess`` module"""
+ """Hook for running processes with the ``subprocess`` module."""
def __init__(self) -> None:
- self.sub_process: Optional[Popen[bytes]] = None
+ self.sub_process: Popen[bytes] | None = None
super().__init__()
def run_command(
self,
- command: List[str],
- env: Optional[Dict[str, str]] = None,
- output_encoding: str = 'utf-8',
- cwd: Optional[str] = None,
+ command: list[str],
+ env: dict[str, str] | None = None,
+ output_encoding: str = "utf-8",
+ cwd: str | None = None,
) -> SubprocessResult:
"""
Execute the command.
@@ -52,26 +53,26 @@ def run_command(
environment in which ``command`` will be executed. If omitted, ``os.environ`` will be used.
Note, that in case you have Sentry configured, original variables from the environment
will also be passed to the subprocess with ``SUBPROCESS_`` prefix. See
- :doc:`/logging-monitoring/errors` for details.
+ :doc:`/administration-and-deployment/logging-monitoring/errors` for details.
:param output_encoding: encoding to use for decoding stdout
:param cwd: Working directory to run the command in.
If None (default), the command is run in a temporary directory.
:return: :class:`namedtuple` containing ``exit_code`` and ``output``, the last line from stderr
or stdout
"""
- self.log.info('Tmp dir root location: \n %s', gettempdir())
+ self.log.info("Tmp dir root location: \n %s", gettempdir())
with contextlib.ExitStack() as stack:
if cwd is None:
- cwd = stack.enter_context(TemporaryDirectory(prefix='airflowtmp'))
+ cwd = stack.enter_context(TemporaryDirectory(prefix="airflowtmp"))
def pre_exec():
# Restore default signal disposition and invoke setsid
- for sig in ('SIGPIPE', 'SIGXFZ', 'SIGXFSZ'):
+ for sig in ("SIGPIPE", "SIGXFZ", "SIGXFSZ"):
if hasattr(signal, sig):
signal.signal(getattr(signal, sig), signal.SIG_DFL)
os.setsid()
- self.log.info('Running command: %s', command)
+ self.log.info("Running command: %s", command)
self.sub_process = Popen(
command,
@@ -82,24 +83,24 @@ def pre_exec():
preexec_fn=pre_exec,
)
- self.log.info('Output:')
- line = ''
+ self.log.info("Output:")
+ line = ""
if self.sub_process is None:
raise RuntimeError("The subprocess should be created here and is None!")
if self.sub_process.stdout is not None:
- for raw_line in iter(self.sub_process.stdout.readline, b''):
- line = raw_line.decode(output_encoding, errors='backslashreplace').rstrip()
+ for raw_line in iter(self.sub_process.stdout.readline, b""):
+ line = raw_line.decode(output_encoding, errors="backslashreplace").rstrip()
self.log.info("%s", line)
self.sub_process.wait()
- self.log.info('Command exited with return code %s', self.sub_process.returncode)
+ self.log.info("Command exited with return code %s", self.sub_process.returncode)
return_code: int = self.sub_process.returncode
return SubprocessResult(exit_code=return_code, output=line)
def send_sigterm(self):
"""Sends SIGTERM signal to ``self.sub_process`` if one exists."""
- self.log.info('Sending SIGTERM signal to process group')
- if self.sub_process and hasattr(self.sub_process, 'pid'):
+ self.log.info("Sending SIGTERM signal to process group")
+ if self.sub_process and hasattr(self.sub_process, "pid"):
os.killpg(os.getpgid(self.sub_process.pid), signal.SIGTERM)
diff --git a/airflow/hooks/webhdfs_hook.py b/airflow/hooks/webhdfs_hook.py
deleted file mode 100644
index 1c4353835cf00..0000000000000
--- a/airflow/hooks/webhdfs_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hdfs.hooks.webhdfs`."""
-
-import warnings
-
-from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hdfs.hooks.webhdfs`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/hooks/zendesk_hook.py b/airflow/hooks/zendesk_hook.py
deleted file mode 100644
index 38323c5880d74..0000000000000
--- a/airflow/hooks/zendesk_hook.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.zendesk.hooks.zendesk`."""
-
-import warnings
-
-from airflow.providers.zendesk.hooks.zendesk import ZendeskHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.zendesk.hooks.zendesk`.",
- DeprecationWarning,
- stacklevel=2,
-)
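Reviewer note on the deleted ``airflow/hooks/*_hook.py`` files: each removed module was a thin shim that re-exported a provider class and emitted a ``DeprecationWarning`` with ``stacklevel=2`` at import time. With the shims gone, code that still imports the old paths should be expected to fail at import, so callers need the provider path named in each deleted docstring. The sketch below shows the warning pattern the shims relied on, using only the standard library; ``warn_deprecated_module`` is a hypothetical helper written for illustration and does not exist in Airflow.

```python
import warnings


def warn_deprecated_module(new_path: str) -> None:
    """Emit the same style of warning the removed shim modules used (illustrative helper)."""
    warnings.warn(
        f"This module is deprecated. Please use `{new_path}`.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller of this function
        # rather than to this helper. In the removed shims the call sat at
        # module level, so the same value pointed past the shim's own code.
        stacklevel=2,
    )


def collect_deprecations(new_path: str) -> list:
    """Trigger the warning and return the recorded messages as strings."""
    with warnings.catch_warnings(record=True) as caught:
        # DeprecationWarning is hidden by default in many contexts; force it on.
        warnings.simplefilter("always")
        warn_deprecated_module(new_path)
    return [str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)]
```

Running a test suite with ``DeprecationWarning`` enabled against the release before this change is one way to find remaining uses of the old import paths before upgrading.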
diff --git a/airflow/jobs/__init__.py b/airflow/jobs/__init__.py
index 2f061162719d5..217e5db960782 100644
--- a/airflow/jobs/__init__.py
+++ b/airflow/jobs/__init__.py
@@ -15,9 +15,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-import airflow.jobs.backfill_job
-import airflow.jobs.base_job
-import airflow.jobs.local_task_job
-import airflow.jobs.scheduler_job
-import airflow.jobs.triggerer_job # noqa
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index c5b98c2a8df2a..1115edc78c5ea 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
import time
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple
+from typing import TYPE_CHECKING, Any, Iterable, Iterator, Sequence
import attr
import pendulum
@@ -50,7 +50,7 @@
from airflow.utils.types import DagRunType
if TYPE_CHECKING:
- from airflow.models.mappedoperator import MappedOperator
+ from airflow.models.abstractoperator import AbstractOperator
class BackfillJob(BaseJob):
@@ -62,7 +62,7 @@ class BackfillJob(BaseJob):
STATES_COUNT_AS_RUNNING = (State.RUNNING, State.QUEUED)
- __mapper_args__ = {'polymorphic_identity': 'BackfillJob'}
+ __mapper_args__ = {"polymorphic_identity": "BackfillJob"}
@attr.define
class _DagRunTaskStatus:
@@ -88,15 +88,15 @@ class _DagRunTaskStatus:
:param total_runs: Number of total dag runs able to run
"""
- to_run: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
- running: Dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
- skipped: Set[TaskInstanceKey] = attr.ib(factory=set)
- succeeded: Set[TaskInstanceKey] = attr.ib(factory=set)
- failed: Set[TaskInstanceKey] = attr.ib(factory=set)
- not_ready: Set[TaskInstanceKey] = attr.ib(factory=set)
- deadlocked: Set[TaskInstance] = attr.ib(factory=set)
- active_runs: List[DagRun] = attr.ib(factory=list)
- executed_dag_run_dates: Set[pendulum.DateTime] = attr.ib(factory=set)
+ to_run: dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
+ running: dict[TaskInstanceKey, TaskInstance] = attr.ib(factory=dict)
+ skipped: set[TaskInstanceKey] = attr.ib(factory=set)
+ succeeded: set[TaskInstanceKey] = attr.ib(factory=set)
+ failed: set[TaskInstanceKey] = attr.ib(factory=set)
+ not_ready: set[TaskInstanceKey] = attr.ib(factory=set)
+ deadlocked: set[TaskInstance] = attr.ib(factory=set)
+ active_runs: list[DagRun] = attr.ib(factory=list)
+ executed_dag_run_dates: set[pendulum.DateTime] = attr.ib(factory=set)
finished_runs: int = 0
total_runs: int = 0
@@ -117,6 +117,7 @@ def __init__(
run_backwards=False,
run_at_least_once=False,
continue_on_failures=False,
+ disable_retry=False,
*args,
**kwargs,
):
@@ -156,6 +157,7 @@ def __init__(
self.run_backwards = run_backwards
self.run_at_least_once = run_at_least_once
self.continue_on_failures = continue_on_failures
+ self.disable_retry = disable_retry
super().__init__(*args, **kwargs)
def _update_counters(self, ti_status, session=None):
@@ -217,6 +219,12 @@ def _update_counters(self, ti_status, session=None):
tis_to_be_scheduled.append(ti)
ti_status.running.pop(reduced_key)
ti_status.to_run[ti.key] = ti
+ # Special case: a deferrable task can go from DEFERRED to SCHEDULED;
+ # when that happens, put it back in to_run, as is done for UP_FOR_RESCHEDULE.
+ elif ti.state == TaskInstanceState.SCHEDULED:
+ self.log.debug("Task instance %s is resumed from deferred state", ti)
+ ti_status.running.pop(ti.key)
+ ti_status.to_run[ti.key] = ti
# Batch schedule of task instances
if tis_to_be_scheduled:
@@ -228,7 +236,7 @@ def _update_counters(self, ti_status, session=None):
def _manage_executor_state(
self, running, session
- ) -> Iterator[Tuple["MappedOperator", str, Sequence[TaskInstance], int]]:
+ ) -> Iterator[tuple[AbstractOperator, str, Sequence[TaskInstance], int]]:
"""
Checks if the executor agrees with the state of task instances
that are running.
@@ -263,10 +271,20 @@ def _manage_executor_state(
self.log.error(msg)
ti.handle_failure(error=msg)
continue
+
+ def _iter_task_needing_expansion() -> Iterator[AbstractOperator]:
+ from airflow.models.mappedoperator import AbstractOperator
+
+ for node in self.dag.get_task(ti.task_id, include_subdags=True).iter_mapped_dependants():
+ if isinstance(node, AbstractOperator):
+ yield node
+ else: # A (mapped) task group. All its children need expansion.
+ yield from node.iter_tasks()
+
if ti.state not in self.STATES_COUNT_AS_RUNNING:
# Don't use ti.task; if this task is mapped, that attribute
# would hold the unmapped task. We need to original task here.
- for node in self.dag.get_task(ti.task_id, include_subdags=True).iter_mapped_dependants():
+ for node in _iter_task_needing_expansion():
new_tis, num_mapped_tis = node.expand_mapped_task(ti.run_id, session=session)
yield node, ti.run_id, new_tis, num_mapped_tis
@@ -292,13 +310,15 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non
# check if we are scheduling on top of a already existing dag_run
# we could find a "scheduled" run instead of a "backfill"
runs = DagRun.find(dag_id=dag.dag_id, execution_date=run_date, session=session)
- run: Optional[DagRun]
+ run: DagRun | None
if runs:
run = runs[0]
if run.state == DagRunState.RUNNING:
respect_dag_max_active_limit = False
# Fixes --conf overwrite for backfills with already existing DagRuns
run.conf = self.conf or {}
+ # start_date is cleared for existing DagRuns
+ run.start_date = timezone.utcnow()
else:
run = None
@@ -326,10 +346,12 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non
run.state = DagRunState.RUNNING
run.run_type = DagRunType.BACKFILL_JOB
run.verify_integrity(session=session)
+
+ run.notify_dagrun_state_changed(msg="started")
return run
@provide_session
- def _task_instances_for_dag_run(self, dag_run, session=None):
+ def _task_instances_for_dag_run(self, dag, dag_run, session=None):
"""
Returns a map of task instance key to task instance object for the tasks to
run in the given dag run.
@@ -349,24 +371,25 @@ def _task_instances_for_dag_run(self, dag_run, session=None):
dag_run.refresh_from_db()
make_transient(dag_run)
+ dag_run.dag = dag
+ info = dag_run.task_instance_scheduling_decisions(session=session)
+ schedulable_tis = info.schedulable_tis
try:
- for ti in dag_run.get_task_instances():
- # all tasks part of the backfill are scheduled to run
- if ti.state == State.NONE:
- ti.set_state(TaskInstanceState.SCHEDULED, session=session)
+ for ti in dag_run.get_task_instances(session=session):
+ if ti in schedulable_tis:
+ ti.set_state(TaskInstanceState.SCHEDULED)
if ti.state != TaskInstanceState.REMOVED:
tasks_to_run[ti.key] = ti
session.commit()
except Exception:
session.rollback()
raise
-
return tasks_to_run
def _log_progress(self, ti_status):
self.log.info(
- '[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | '
- 'running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s',
+ "[backfill progress] | finished run %s of %s | tasks waiting: %s | succeeded: %s | "
+ "running: %s | failed: %s | skipped: %s | deadlocked: %s | not ready: %s",
ti_status.finished_runs,
ti_status.total_runs,
len(ti_status.to_run),
@@ -388,7 +411,7 @@ def _process_backfill_task_instances(
pickle_id,
start_date=None,
session=None,
- ):
+ ) -> list:
"""
Process a set of task instances from a set of dag runs. Special handling is done
to account for different task instance states that could be present when running
@@ -400,11 +423,10 @@ def _process_backfill_task_instances(
:param start_date: the start date of the backfill job
:param session: the current session object
:return: the list of execution_dates for the finished dag runs
- :rtype: list
"""
executed_run_dates = []
- is_unit_test = airflow_conf.getboolean('core', 'unit_test_mode')
+ is_unit_test = airflow_conf.getboolean("core", "unit_test_mode")
while (len(ti_status.to_run) > 0 or len(ti_status.running) > 0) and len(ti_status.deadlocked) == 0:
self.log.debug("*** Clearing out not_ready list ***")
@@ -439,13 +461,6 @@ def _per_task_process(key, ti: TaskInstance, session=None):
ti_status.running.pop(key)
return
- # guard against externally modified tasks instances or
- # in case max concurrency has been reached at task runtime
- elif ti.state == State.NONE:
- self.log.warning(
- "FIXME: Task instance %s state was set to None externally. This should not happen", ti
- )
- ti.set_state(TaskInstanceState.SCHEDULED, session=session)
if self.rerun_failed_tasks:
# Rerun failed tasks or upstreamed failed tasks
if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
@@ -485,7 +500,7 @@ def _per_task_process(key, ti: TaskInstance, session=None):
if executor.has_task(ti):
self.log.debug("Task Instance %s already in executor waiting for queue to clear", ti)
else:
- self.log.debug('Sending %s to executor', ti)
+ self.log.debug("Sending %s to executor", ti)
# Skip scheduled state, we are executing immediately
ti.state = TaskInstanceState.QUEUED
ti.queued_by_job_id = self.id
@@ -544,7 +559,7 @@ def _per_task_process(key, ti: TaskInstance, session=None):
return
# all remaining tasks
- self.log.debug('Adding %s to not_ready', ti)
+ self.log.debug("Adding %s to not_ready", ti)
ti_status.not_ready.add(key)
try:
@@ -555,7 +570,7 @@ def _per_task_process(key, ti: TaskInstance, session=None):
pool = session.query(models.Pool).filter(models.Pool.pool == task.pool).first()
if not pool:
- raise PoolNotFound(f'Unknown pool: {task.pool}')
+ raise PoolNotFound(f"Unknown pool: {task.pool}")
open_slots = pool.open_slots(session=session)
if open_slots <= 0:
@@ -627,6 +642,11 @@ def to_keep(key: TaskInstanceKey) -> bool:
for new_ti in new_mapped_tis:
new_ti.set_state(TaskInstanceState.SCHEDULED, session=session)
+ # Set state to failed for running TIs that are set up for retry if disable-retry flag is set
+ for ti in ti_status.running.values():
+ if self.disable_retry and ti.state == TaskInstanceState.UP_FOR_RETRY:
+ ti.set_state(TaskInstanceState.FAILED, session=session)
+
# update the task counters
self._update_counters(ti_status=ti_status, session=session)
session.commit()
@@ -669,12 +689,12 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str:
return tabulate(sorted_ti_keys, headers=headers)
- err = ''
+ err = ""
if ti_status.failed:
err += "Some task instances failed:\n"
err += tabulate_ti_keys_set(ti_status.failed)
if ti_status.deadlocked:
- err += 'BackfillJob is deadlocked.'
+ err += "BackfillJob is deadlocked."
deadlocked_depends_on_past = any(
t.are_dependencies_met(
dep_context=DepContext(ignore_depends_on_past=False),
@@ -688,26 +708,26 @@ def tabulate_ti_keys_set(ti_keys: Iterable[TaskInstanceKey]) -> str:
)
if deadlocked_depends_on_past:
err += (
- 'Some of the deadlocked tasks were unable to run because '
+ "Some of the deadlocked tasks were unable to run because "
'of "depends_on_past" relationships. Try running the '
- 'backfill with the option '
+ "backfill with the option "
'"ignore_first_depends_on_past=True" or passing "-I" at '
- 'the command line.'
+ "the command line."
)
- err += '\nThese tasks have succeeded:\n'
+ err += "\nThese tasks have succeeded:\n"
err += tabulate_ti_keys_set(ti_status.succeeded)
- err += '\n\nThese tasks are running:\n'
+ err += "\n\nThese tasks are running:\n"
err += tabulate_ti_keys_set(ti_status.running)
- err += '\n\nThese tasks have failed:\n'
+ err += "\n\nThese tasks have failed:\n"
err += tabulate_ti_keys_set(ti_status.failed)
- err += '\n\nThese tasks are skipped:\n'
+ err += "\n\nThese tasks are skipped:\n"
err += tabulate_ti_keys_set(ti_status.skipped)
- err += '\n\nThese tasks are deadlocked:\n'
+ err += "\n\nThese tasks are deadlocked:\n"
err += tabulate_ti_keys_set([ti.key for ti in ti_status.deadlocked])
return err
- def _get_dag_with_subdags(self) -> List[DAG]:
+ def _get_dag_with_subdags(self) -> list[DAG]:
return [self.dag] + self.dag.subdags
@provide_session
@@ -727,7 +747,7 @@ def _execute_dagruns(self, dagrun_infos, ti_status, executor, pickle_id, start_d
for dagrun_info in dagrun_infos:
for dag in self._get_dag_with_subdags():
dag_run = self._get_dag_run(dagrun_info, dag, session=session)
- tis_map = self._task_instances_for_dag_run(dag_run, session=session)
+ tis_map = self._task_instances_for_dag_run(dag, dag_run, session=session)
if dag_run is None:
continue
@@ -781,7 +801,7 @@ def _execute(self, session=None):
tasks_that_depend_on_past = [t.task_id for t in self.dag.task_dict.values() if t.depends_on_past]
if tasks_that_depend_on_past:
raise AirflowException(
- f'You cannot backfill backwards because one or more '
+ f"You cannot backfill backwards because one or more "
f'tasks depend_on_past: {",".join(tasks_that_depend_on_past)}'
)
dagrun_infos = dagrun_infos[::-1]
@@ -831,7 +851,7 @@ def _execute(self, session=None):
pickle_id = pickle.id
executor = self.executor
- executor.job_id = "backfill"
+ executor.job_id = self.id
executor.start()
ti_status.total_runs = len(dagrun_infos) # total dag runs in backfill
@@ -876,10 +896,10 @@ def _execute(self, session=None):
session.commit()
executor.end()
- self.log.info("Backfill done. Exiting.")
+ self.log.info("Backfill done for DAG %s. Exiting.", self.dag)
@provide_session
- def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None):
+ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None) -> int | None:
"""
This function checks if there are any tasks in the dagrun (or all) that
have a schedule or queued states but are not known by the executor. If
@@ -889,7 +909,6 @@ def reset_state_for_orphaned_tasks(self, filter_by_dag_run=None, session=None):
:param filter_by_dag_run: the dag_run we want to process, None if all
:return: the number of TIs reset
- :rtype: int
"""
queued_tis = self.executor.queued_tasks
# also consider running as the state might not have changed in the db yet
@@ -934,7 +953,7 @@ def query(result, items):
reset_tis = helpers.reduce_in_chunks(query, tis_to_reset, [], self.max_tis_per_query)
- task_instance_str = '\n\t'.join(repr(x) for x in reset_tis)
+ task_instance_str = "\n\t".join(repr(x) for x in reset_tis)
session.flush()
self.log.info("Reset the following %s TaskInstances:\n\t%s", len(reset_tis), task_instance_str)
diff --git a/airflow/jobs/base_job.py b/airflow/jobs/base_job.py
index 7befccf956c41..d695615a5d8ba 100644
--- a/airflow/jobs/base_job.py
+++ b/airflow/jobs/base_job.py
@@ -15,12 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
from time import sleep
-from typing import Optional
-from sqlalchemy import Column, Index, Integer, String
+from sqlalchemy import Column, Index, Integer, String, case
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import backref, foreign, relationship
from sqlalchemy.orm.session import make_transient
@@ -29,9 +28,8 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.executor_loader import ExecutorLoader
+from airflow.listeners.listener import get_listener_manager
from airflow.models.base import ID_LEN, Base
-from airflow.models.dagrun import DagRun
-from airflow.models.taskinstance import TaskInstance
from airflow.stats import Stats
from airflow.utils import timezone
from airflow.utils.helpers import convert_camel_to_snake
@@ -43,6 +41,12 @@
from airflow.utils.state import State
+def _resolve_dagrun_model():
+ from airflow.models.dagrun import DagRun
+
+ return DagRun
+
+
class BaseJob(Base, LoggingMixin):
"""
Abstract class to be derived for jobs. Jobs are processing items with state
@@ -66,24 +70,24 @@ class BaseJob(Base, LoggingMixin):
hostname = Column(String(500))
unixname = Column(String(1000))
- __mapper_args__ = {'polymorphic_on': job_type, 'polymorphic_identity': 'BaseJob'}
+ __mapper_args__ = {"polymorphic_on": job_type, "polymorphic_identity": "BaseJob"}
__table_args__ = (
- Index('job_type_heart', job_type, latest_heartbeat),
- Index('idx_job_state_heartbeat', state, latest_heartbeat),
- Index('idx_job_dag_id', dag_id),
+ Index("job_type_heart", job_type, latest_heartbeat),
+ Index("idx_job_state_heartbeat", state, latest_heartbeat),
+ Index("idx_job_dag_id", dag_id),
)
task_instances_enqueued = relationship(
- TaskInstance,
- primaryjoin=id == foreign(TaskInstance.queued_by_job_id), # type: ignore[has-type]
- backref=backref('queued_by_job', uselist=False),
+ "TaskInstance",
+ primaryjoin="BaseJob.id == foreign(TaskInstance.queued_by_job_id)",
+ backref=backref("queued_by_job", uselist=False),
)
dag_runs = relationship(
- DagRun,
- primaryjoin=id == foreign(DagRun.creating_job_id), # type: ignore[has-type]
- backref=backref('creating_job'),
+ "DagRun",
+ primaryjoin=lambda: BaseJob.id == foreign(_resolve_dagrun_model().creating_job_id),
+ backref="creating_job",
)
"""
@@ -92,7 +96,7 @@ class BaseJob(Base, LoggingMixin):
Only makes sense for SchedulerJob and BackfillJob instances.
"""
- heartrate = conf.getfloat('scheduler', 'JOB_HEARTBEAT_SEC')
+ heartrate = conf.getfloat("scheduler", "JOB_HEARTBEAT_SEC")
def __init__(self, executor=None, heartrate=None, *args, **kwargs):
self.hostname = get_hostname()
@@ -100,13 +104,14 @@ def __init__(self, executor=None, heartrate=None, *args, **kwargs):
self.executor = executor
self.executor_class = executor.__class__.__name__
else:
- self.executor_class = conf.get('core', 'EXECUTOR')
+ self.executor_class = conf.get("core", "EXECUTOR")
self.start_date = timezone.utcnow()
self.latest_heartbeat = timezone.utcnow()
if heartrate is not None:
self.heartrate = heartrate
self.unixname = getuser()
- self.max_tis_per_query: int = conf.getint('scheduler', 'max_tis_per_query')
+ self.max_tis_per_query: int = conf.getint("scheduler", "max_tis_per_query")
+ get_listener_manager().hook.on_starting(component=self)
super().__init__(*args, **kwargs)
@cached_property
@@ -115,18 +120,28 @@ def executor(self):
@classmethod
@provide_session
- def most_recent_job(cls, session=None) -> Optional['BaseJob']:
+ def most_recent_job(cls, session=None) -> BaseJob | None:
"""
Return the most recent job of this type, if any, based on last
heartbeat received.
+ Jobs in "running" state take precedence over others to make sure alive
+ job is returned if it is available.
This method should be called on a subclass (i.e. on SchedulerJob) to
return jobs of that type.
:param session: Database session
- :rtype: BaseJob or None
"""
- return session.query(cls).order_by(cls.latest_heartbeat.desc()).limit(1).first()
+ return (
+ session.query(cls)
+ .order_by(
+ # Put "running" jobs at the front.
+ case({State.RUNNING: 0}, value=cls.state, else_=1),
+ cls.latest_heartbeat.desc(),
+ )
+ .limit(1)
+ .first()
+ )
def is_alive(self, grace_multiplier=2.1):
"""
@@ -137,7 +152,6 @@ def is_alive(self, grace_multiplier=2.1):
:param grace_multiplier: multiplier of heartrate to require heart beat
within
- :rtype: boolean
"""
return (
self.state == State.RUNNING
@@ -153,7 +167,7 @@ def kill(self, session=None):
try:
self.on_kill()
except Exception as e:
- self.log.error('on_kill() method failed: %s', str(e))
+ self.log.error("on_kill() method failed: %s", str(e))
session.merge(job)
session.commit()
raise AirflowException("Job shut down externally.")
@@ -223,16 +237,16 @@ def heartbeat(self, only_if_necessary: bool = False):
previous_heartbeat = self.latest_heartbeat
self.heartbeat_callback(session=session)
- self.log.debug('[heartbeat]')
+ self.log.debug("[heartbeat]")
except OperationalError:
- Stats.incr(convert_camel_to_snake(self.__class__.__name__) + '_heartbeat_failure', 1, 1)
+ Stats.incr(convert_camel_to_snake(self.__class__.__name__) + "_heartbeat_failure", 1, 1)
self.log.exception("%s heartbeat got an exception", self.__class__.__name__)
# We didn't manage to heartbeat, so make sure that the timestamp isn't updated
self.latest_heartbeat = previous_heartbeat
def run(self):
"""Starts the job."""
- Stats.incr(self.__class__.__name__.lower() + '_start', 1, 1)
+ Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
# Adding an entry in the DB
with create_session() as session:
self.state = State.RUNNING
@@ -251,11 +265,12 @@ def run(self):
self.state = State.FAILED
raise
finally:
+ get_listener_manager().hook.before_stopping(component=self)
self.end_date = timezone.utcnow()
session.merge(self)
session.commit()
- Stats.incr(self.__class__.__name__.lower() + '_end', 1, 1)
+ Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)
def _execute(self):
raise NotImplementedError("This method needs to be overridden")
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 5711342e04d97..cc449e1fc56fb 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -15,34 +15,60 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import signal
-from typing import Optional
import psutil
-from sqlalchemy.exc import OperationalError
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.base_job import BaseJob
from airflow.listeners.events import register_task_instance_state_events
from airflow.listeners.listener import get_listener_manager
-from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
-from airflow.sentry import Sentry
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State
+SIGSEGV_MESSAGE = """
+******************************************* Received SIGSEGV *******************************************
+SIGSEGV (Segmentation Violation) signal indicates Segmentation Fault error which refers to
+an attempt by a program/library to write or read outside its allocated memory.
+
+In Python environment usually this signal refers to libraries which use low level C API.
+Make sure that you use the right libraries/Docker Images
+for your architecture (Intel/ARM) and/or Operating System (Linux/macOS).
+
+Suggested way to debug
+======================
+ - Set environment variable 'PYTHONFAULTHANDLER' to 'true'.
+ - Start airflow services.
+ - Restart failed airflow task.
+ - Check 'scheduler' and 'worker' services logs for additional traceback
+ which might contain information about the module/library where the actual error happens.
+
+Known Issues
+============
+
+Note: Only Linux-based distros are supported as "Production" execution environment for Airflow.
+
+macOS
+-----
+ 1. Due to limitations in Apple's libraries, not every process might be 'fork'-safe.
+ One common error is being unable to query the macOS system configuration for network proxies.
+ If you are not using a proxy, you could disable it by setting environment variable 'no_proxy' to '*'.
+ See: https://github.com/python/cpython/issues/58037 and https://bugs.python.org/issue30385#msg293958
+********************************************************************************************************"""
+
class LocalTaskJob(BaseJob):
"""LocalTaskJob runs a single task instance."""
- __mapper_args__ = {'polymorphic_identity': 'LocalTaskJob'}
+ __mapper_args__ = {"polymorphic_identity": "LocalTaskJob"}
def __init__(
self,
@@ -52,9 +78,9 @@ def __init__(
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
mark_success: bool = False,
- pickle_id: Optional[str] = None,
- pool: Optional[str] = None,
- external_executor_id: Optional[str] = None,
+ pickle_id: str | None = None,
+ pool: str | None = None,
+ external_executor_id: str | None = None,
*args,
**kwargs,
):
@@ -87,7 +113,26 @@ def signal_handler(signum, frame):
self.task_runner.terminate()
self.handle_task_exit(128 + signum)
+ def segfault_signal_handler(signum, frame):
+ """Setting segmentation violation signal handler"""
+ self.log.critical(SIGSEGV_MESSAGE)
+ self.task_runner.terminate()
+ self.handle_task_exit(128 + signum)
+ raise AirflowException("Segmentation Fault detected.")
+
+ def sigusr2_debug_handler(signum, frame):
+ import sys
+ import threading
+ import traceback
+
+ id2name = {th.ident: th.name for th in threading.enumerate()}
+ for threadId, stack in sys._current_frames().items():
+ print(id2name[threadId])
+ traceback.print_stack(f=stack)
+
+ signal.signal(signal.SIGSEGV, segfault_signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
+ signal.signal(signal.SIGUSR2, sigusr2_debug_handler)
if not self.task_instance.check_and_change_state_before_execution(
mark_success=self.mark_success,
@@ -105,7 +150,7 @@ def signal_handler(signum, frame):
try:
self.task_runner.start()
- heartbeat_time_limit = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
+ heartbeat_time_limit = conf.getint("scheduler", "scheduler_zombie_task_threshold")
# LocalTaskJob should not run callbacks, which are handled by TaskInstance._run_raw_task
# 1, LocalTaskJob does not parse DAG, thus cannot run callbacks
@@ -143,7 +188,7 @@ def signal_handler(signum, frame):
# This can only really happen if the worker can't read the DB for a long time
time_since_last_heartbeat = (timezone.utcnow() - self.latest_heartbeat).total_seconds()
if time_since_last_heartbeat > heartbeat_time_limit:
- Stats.incr('local_task_job_prolonged_heartbeat_failure', 1, 1)
+ Stats.incr("local_task_job_prolonged_heartbeat_failure", 1, 1)
self.log.error("Heartbeat time limit exceeded!")
raise AirflowException(
f"Time since last heartbeat({time_since_last_heartbeat:.2f}s) exceeded limit "
@@ -161,10 +206,11 @@ def handle_task_exit(self, return_code: int) -> None:
# Without setting this, heartbeat may get us
self.terminating = True
self.log.info("Task exited with return code %s", return_code)
+ self._log_return_code_metric(return_code)
if not self.task_instance.test_mode:
- if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
- self._run_mini_scheduler_on_child_tasks()
+ if conf.getboolean("scheduler", "schedule_after_task_execution", fallback=True):
+ self.task_instance.schedule_downstream_tasks()
def on_kill(self):
self.task_runner.terminate()
@@ -195,7 +241,13 @@ def heartbeat_callback(self, session=None):
recorded_pid = ti.pid
same_process = recorded_pid == current_pid
- if ti.run_as_user or self.task_runner.run_as_user:
+ if recorded_pid is not None and (ti.run_as_user or self.task_runner.run_as_user):
+ # when running as another user, compare the task runner pid to the parent of
+ # the recorded pid because user delegation becomes an extra process level.
+ # However, if recorded_pid is None, pass that through as it signals the task
+ # runner process has already completed and been cleared out. `psutil.Process`
+ # uses the current process if the parameter is None, which is not what is intended
+ # for comparison.
recorded_pid = psutil.Process(ti.pid).ppid()
same_process = recorded_pid == current_pid
@@ -204,7 +256,7 @@ def heartbeat_callback(self, session=None):
"Recorded pid %s does not match the current pid %s", recorded_pid, current_pid
)
raise AirflowException("PID of job runner does not match")
- elif self.task_runner.return_code() is None and hasattr(self.task_runner, 'process'):
+ elif self.task_runner.return_code() is None and hasattr(self.task_runner, "process"):
if ti.state == State.SKIPPED:
# A DagRun timeout will cause tasks to be externally marked as skipped.
dagrun = ti.get_dagrun(session=session)
@@ -223,56 +275,10 @@ def heartbeat_callback(self, session=None):
self.terminating = True
self._state_change_checks += 1
- @provide_session
- @Sentry.enrich_errors
- def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
- try:
- # Re-select the row with a lock
- dag_run = with_row_locks(
- session.query(DagRun).filter_by(
- dag_id=self.dag_id,
- run_id=self.task_instance.run_id,
- ),
- session=session,
- ).one()
-
- task = self.task_instance.task
- assert task.dag # For Mypy.
-
- # Get a partial DAG with just the specific tasks we want to examine.
- # In order for dep checks to work correctly, we include ourself (so
- # TriggerRuleDep can check the state of the task we just executed).
- partial_dag = task.dag.partial_subset(
- task.downstream_task_ids,
- include_downstream=True,
- include_upstream=False,
- include_direct_upstream=True,
- )
-
- dag_run.dag = partial_dag
- info = dag_run.task_instance_scheduling_decisions(session)
-
- skippable_task_ids = {
- task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
- }
-
- schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
- for schedulable_ti in schedulable_tis:
- if not hasattr(schedulable_ti, "task"):
- schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
-
- num = dag_run.schedule_tis(schedulable_tis)
- self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
- session.commit()
- except OperationalError as e:
- # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
- self.log.info(
- "Skipping mini scheduling run due to exception: %s",
- e.statement,
- exc_info=True,
- )
- session.rollback()
+ def _log_return_code_metric(self, return_code: int):
+ Stats.incr(
+ f"local_task_job.task_exit.{self.id}.{self.dag_id}.{self.task_instance.task_id}.{return_code}"
+ )
@staticmethod
def _enable_task_listeners():
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index 22ba5decb4111..b8b608efcd847 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import itertools
import logging
import multiprocessing
@@ -25,37 +26,44 @@
import time
import warnings
from collections import defaultdict
-from datetime import timedelta
-from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import TYPE_CHECKING, Collection, DefaultDict, Iterator
-from sqlalchemy import func, not_, or_, text
+from sqlalchemy import and_, func, not_, or_, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import load_only, selectinload
from sqlalchemy.orm.session import Session, make_transient
+from sqlalchemy.sql import expression
-from airflow import models, settings
+from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest, SlaCallbackRequest, TaskCallbackRequest
-from airflow.callbacks.database_callback_sink import DatabaseCallbackSink
from airflow.callbacks.pipe_callback_sink import PipeCallbackSink
from airflow.configuration import conf
-from airflow.dag_processing.manager import DagFileProcessorAgent
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.executors.executor_loader import UNPICKLEABLE_EXECUTORS
from airflow.jobs.base_job import BaseJob
-from airflow.jobs.local_task_job import LocalTaskJob
-from airflow.models import DAG
-from airflow.models.dag import DagModel
+from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
+from airflow.models.dataset import (
+ DagScheduleDatasetReference,
+ DatasetDagRunQueue,
+ DatasetEvent,
+ DatasetModel,
+ TaskOutletDatasetReference,
+)
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.stats import Stats
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
+from airflow.timetables.simple import DatasetTriggeredTimetable
from airflow.utils import timezone
-from airflow.utils.docs import get_docs_url
from airflow.utils.event_scheduler import EventScheduler
from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
-from airflow.utils.session import create_session, provide_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import (
+ CommitProhibitorGuard,
is_lock_not_available_error,
prohibit_commit,
skip_locked,
@@ -65,17 +73,22 @@
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
-TI = models.TaskInstance
-DR = models.DagRun
-DM = models.DagModel
+if TYPE_CHECKING:
+ from types import FrameType
+
+ from airflow.dag_processing.manager import DagFileProcessorAgent
+
+TI = TaskInstance
+DR = DagRun
+DM = DagModel
-def _is_parent_process():
+def _is_parent_process() -> bool:
"""
Returns True if the current process is the parent process. False if the current process is a child
process started by multiprocessing.
"""
- return multiprocessing.current_process().name == 'MainProcess'
+ return multiprocessing.current_process().name == "MainProcess"
class SchedulerJob(BaseJob):
@@ -100,18 +113,18 @@ class SchedulerJob(BaseJob):
:param log: override the default Logger
"""
- __mapper_args__ = {'polymorphic_identity': 'SchedulerJob'}
- heartrate: int = conf.getint('scheduler', 'SCHEDULER_HEARTBEAT_SEC')
+ __mapper_args__ = {"polymorphic_identity": "SchedulerJob"}
+ heartrate: int = conf.getint("scheduler", "SCHEDULER_HEARTBEAT_SEC")
def __init__(
self,
subdir: str = settings.DAGS_FOLDER,
- num_runs: int = conf.getint('scheduler', 'num_runs'),
+ num_runs: int = conf.getint("scheduler", "num_runs"),
num_times_parse_dags: int = -1,
- scheduler_idle_sleep_time: float = conf.getfloat('scheduler', 'scheduler_idle_sleep_time'),
+ scheduler_idle_sleep_time: float = conf.getfloat("scheduler", "scheduler_idle_sleep_time"),
do_pickle: bool = False,
- log: Optional[logging.Logger] = None,
- processor_poll_interval: Optional[float] = None,
+ log: logging.Logger | None = None,
+ processor_poll_interval: float | None = None,
*args,
**kwargs,
):
@@ -127,14 +140,15 @@ def __init__(
warnings.warn(
"The 'processor_poll_interval' parameter is deprecated. "
"Please use 'scheduler_idle_sleep_time'.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
scheduler_idle_sleep_time = processor_poll_interval
self._scheduler_idle_sleep_time = scheduler_idle_sleep_time
# How many seconds do we wait for tasks to heartbeat before mark them as zombies.
- self._zombie_threshold_secs = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
+ self._zombie_threshold_secs = conf.getint("scheduler", "scheduler_zombie_task_threshold")
self._standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
+ self._dag_stale_not_seen_duration = conf.getint("scheduler", "dag_stale_not_seen_duration")
self.do_pickle = do_pickle
super().__init__(*args, **kwargs)
@@ -142,28 +156,13 @@ def __init__(
self._log = log
# Check what SQL backend we use
- sql_conn: str = conf.get_mandatory_value('database', 'sql_alchemy_conn').lower()
- self.using_sqlite = sql_conn.startswith('sqlite')
- self.using_mysql = sql_conn.startswith('mysql')
+ sql_conn: str = conf.get_mandatory_value("database", "sql_alchemy_conn").lower()
+ self.using_sqlite = sql_conn.startswith("sqlite")
# Dag Processor agent - not used in Dag Processor standalone mode.
- self.processor_agent: Optional[DagFileProcessorAgent] = None
+ self.processor_agent: DagFileProcessorAgent | None = None
self.dagbag = DagBag(dag_folder=self.subdir, read_dags_from_db=True, load_op_links=False)
- self._paused_dag_without_running_dagruns: Set = set()
-
- if conf.getboolean('smart_sensor', 'use_smart_sensor'):
- compatible_sensors = set(
- map(
- lambda l: l.strip(),
- conf.get_mandatory_value('smart_sensor', 'sensors_enabled').split(','),
- )
- )
- docs_url = get_docs_url('concepts/smart-sensors.html#migrating-to-deferrable-operators')
- warnings.warn(
- f'Smart sensors are deprecated, yet can be used for {compatible_sensors} sensors.'
- f' Please use Deferrable Operators instead. See {docs_url} for more info.',
- DeprecationWarning,
- )
+ self._paused_dag_without_running_dagruns: set = set()
def register_signals(self) -> None:
"""Register signals that stop child processes"""
@@ -171,7 +170,7 @@ def register_signals(self) -> None:
signal.signal(signal.SIGTERM, self._exit_gracefully)
signal.signal(signal.SIGUSR2, self._debug_dump)
- def _exit_gracefully(self, signum, frame) -> None:
+ def _exit_gracefully(self, signum: int, frame: FrameType | None) -> None:
"""Helper method to clean up processor_agent to avoid leaving orphan processes."""
if not _is_parent_process():
# Only the parent process should perform the cleanup.
@@ -182,7 +181,7 @@ def _exit_gracefully(self, signum, frame) -> None:
self.processor_agent.end()
sys.exit(os.EX_OK)
- def _debug_dump(self, signum, frame):
+ def _debug_dump(self, signum: int, frame: FrameType | None) -> None:
if not _is_parent_process():
# Only the parent process should perform the debug dump.
return
@@ -197,7 +196,7 @@ def _debug_dump(self, signum, frame):
self.executor.debug_dump()
self.log.info("-" * 80)
- def is_alive(self, grace_multiplier: Optional[float] = None) -> bool:
+ def is_alive(self, grace_multiplier: float | None = None) -> bool:
"""
Is this SchedulerJob alive?
@@ -207,44 +206,40 @@ def is_alive(self, grace_multiplier: Optional[float] = None) -> bool:
``grace_multiplier`` is accepted for compatibility with the parent class.
- :rtype: boolean
"""
if grace_multiplier is not None:
# Accept the same behaviour as superclass
return super().is_alive(grace_multiplier=grace_multiplier)
- scheduler_health_check_threshold: int = conf.getint('scheduler', 'scheduler_health_check_threshold')
+ scheduler_health_check_threshold: int = conf.getint("scheduler", "scheduler_health_check_threshold")
return (
self.state == State.RUNNING
and (timezone.utcnow() - self.latest_heartbeat).total_seconds() < scheduler_health_check_threshold
)
- @provide_session
def __get_concurrency_maps(
- self, states: List[TaskInstanceState], session: Session = None
- ) -> Tuple[DefaultDict[str, int], DefaultDict[Tuple[str, str], int]]:
+ self, states: list[TaskInstanceState], session: Session
+ ) -> tuple[DefaultDict[str, int], DefaultDict[tuple[str, str], int]]:
"""
Get the concurrency maps.
:param states: List of states to query for
:return: A map from dag_id to # of task instances and
a map from (dag_id, task_id) to # of task instances in the given state list
- :rtype: tuple[dict[str, int], dict[tuple[str, str], int]]
"""
- ti_concurrency_query: List[Tuple[str, str, int]] = (
- session.query(TI.task_id, TI.dag_id, func.count('*'))
+ ti_concurrency_query: list[tuple[str, str, int]] = (
+ session.query(TI.task_id, TI.dag_id, func.count("*"))
.filter(TI.state.in_(states))
.group_by(TI.task_id, TI.dag_id)
).all()
dag_map: DefaultDict[str, int] = defaultdict(int)
- task_map: DefaultDict[Tuple[str, str], int] = defaultdict(int)
+ task_map: DefaultDict[tuple[str, str], int] = defaultdict(int)
for result in ti_concurrency_query:
task_id, dag_id, count = result
dag_map[dag_id] += count
task_map[(dag_id, task_id)] = count
return dag_map, task_map
- @provide_session
- def _executable_task_instances_to_queued(self, max_tis: int, session: Session = None) -> List[TI]:
+ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
"""
Finds TIs that are ready for execution with respect to pool limits,
dag max_active_tasks, executor state, and priority.
@@ -252,9 +247,10 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
:param max_tis: Maximum number of TIs to queue in this loop.
:return: list[airflow.models.TaskInstance]
"""
+ from airflow.models.pool import Pool
from airflow.utils.db import DBLocks
- executable_tis: List[TI] = []
+ executable_tis: list[TI] = []
if session.get_bind().dialect.name == "postgresql":
# Optimization: to avoid littering the DB errors of "ERROR: canceling statement due to lock
@@ -268,16 +264,16 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
if not lock_acquired:
# Throw an error like the one that would happen with NOWAIT
raise OperationalError(
- "Failed to acquire advisory lock", params=None, orig=RuntimeError('55P03')
+ "Failed to acquire advisory lock", params=None, orig=RuntimeError("55P03")
)
# Get the pool settings. We get a lock on the pool rows, treating this as a "critical section"
# Throws an exception if lock cannot be obtained, rather than blocking
- pools = models.Pool.slots_stats(lock_rows=True, session=session)
+ pools = Pool.slots_stats(lock_rows=True, session=session)
# If the pools are full, there is no point doing anything!
# If _somehow_ the pool is overfull, don't let the limit go negative - it breaks SQL
- pool_slots_free = sum(max(0, pool['open']) for pool in pools.values())
+ pool_slots_free = sum(max(0, pool["open"]) for pool in pools.values())
if pool_slots_free == 0:
self.log.debug("All pools are full!")
@@ -285,11 +281,11 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
max_tis = min(max_tis, pool_slots_free)
- starved_pools = {pool_name for pool_name, stats in pools.items() if stats['open'] <= 0}
+ starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}
# dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
dag_active_tasks_map: DefaultDict[str, int]
- task_concurrency_map: DefaultDict[Tuple[str, str], int]
+ task_concurrency_map: DefaultDict[tuple[str, str], int]
dag_active_tasks_map, task_concurrency_map = self.__get_concurrency_maps(
states=list(EXECUTION_STATES), session=session
)
@@ -299,8 +295,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
num_starving_tasks_total = 0
# dag and task ids that can't be queued because of concurrency limits
- starved_dags: Set[str] = set()
- starved_tasks: Set[Tuple[str, str]] = set()
+ starved_dags: set[str] = set()
+ starved_tasks: set[tuple[str, str]] = set()
pool_num_starving_tasks: DefaultDict[str, int] = defaultdict(int)
@@ -315,13 +311,14 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
# and the dag is not paused
query = (
session.query(TI)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
.join(TI.dag_run)
.filter(DR.run_type != DagRunType.BACKFILL_JOB, DR.state == DagRunState.RUNNING)
.join(TI.dag_model)
.filter(not_(DM.is_paused))
.filter(TI.state == TaskInstanceState.SCHEDULED)
- .options(selectinload('dag_model'))
- .order_by(-TI.priority_weight, DR.execution_date)
+ .options(selectinload("dag_model"))
+ .order_by(-TI.priority_weight, DR.execution_date, TI.map_index)
)
if starved_pools:
@@ -336,12 +333,21 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
query = query.limit(max_tis)
- task_instances_to_examine: List[TI] = with_row_locks(
- query,
- of=TI,
- session=session,
- **skip_locked(session=session),
- ).all()
+ timer = Stats.timer("scheduler.critical_section_query_duration")
+ timer.start()
+
+ try:
+ task_instances_to_examine: list[TI] = with_row_locks(
+ query,
+ of=TI,
+ session=session,
+ **skip_locked(session=session),
+ ).all()
+ timer.stop(send=True)
+ except OperationalError as e:
+ timer.stop(send=False)
+ raise e
+
# TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
# Stats.gauge('scheduler.tasks.pending', len(task_instances_to_examine))
@@ -444,12 +450,12 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
dag_id,
task_instance,
)
- session.query(TI).filter(TI.dag_id == dag_id, TI.state == State.SCHEDULED).update(
- {TI.state: State.FAILED}, synchronize_session='fetch'
- )
+ session.query(TI).filter(
+ TI.dag_id == dag_id, TI.state == TaskInstanceState.SCHEDULED
+ ).update({TI.state: TaskInstanceState.FAILED}, synchronize_session="fetch")
continue
- task_concurrency_limit: Optional[int] = None
+ task_concurrency_limit: int | None = None
if serialized_dag.has_task(task_instance.task_id):
task_concurrency_limit = serialized_dag.get_task(
task_instance.task_id
@@ -494,11 +500,11 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
)
for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
- Stats.gauge(f'pool.starving_tasks.{pool_name}', num_starving_tasks)
+ Stats.gauge(f"pool.starving_tasks.{pool_name}", num_starving_tasks)
- Stats.gauge('scheduler.tasks.starving', num_starving_tasks_total)
- Stats.gauge('scheduler.tasks.running', num_tasks_in_executor)
- Stats.gauge('scheduler.tasks.executable', len(executable_tis))
+ Stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
+ Stats.gauge("scheduler.tasks.running", num_tasks_in_executor)
+ Stats.gauge("scheduler.tasks.executable", len(executable_tis))
if len(executable_tis) > 0:
task_instance_str = "\n\t".join(repr(x) for x in executable_tis)
@@ -521,10 +527,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
make_transient(ti)
return executable_tis
- @provide_session
- def _enqueue_task_instances_with_queued_state(
- self, task_instances: List[TI], session: Session = None
- ) -> None:
+ def _enqueue_task_instances_with_queued_state(self, task_instances: list[TI], session: Session) -> None:
"""
Takes task_instances, which should have been set to queued, and enqueues them
with the executor.
@@ -581,14 +584,13 @@ def _critical_section_enqueue_task_instances(self, session: Session) -> int:
self._enqueue_task_instances_with_queued_state(queued_tis, session=session)
return len(queued_tis)
- @provide_session
- def _process_executor_events(self, session: Session = None) -> int:
+ def _process_executor_events(self, session: Session) -> int:
"""Respond to executor events."""
if not self._standalone_dag_processor and not self.processor_agent:
raise ValueError("Processor agent is not started.")
- ti_primary_key_to_try_number_map: Dict[Tuple[str, str, str, int], int] = {}
+ ti_primary_key_to_try_number_map: dict[tuple[str, str, str, int], int] = {}
event_buffer = self.executor.get_event_buffer()
- tis_with_right_state: List[TaskInstanceKey] = []
+ tis_with_right_state: list[TaskInstanceKey] = []
# Report execution
for ti_key, value in event_buffer.items():
@@ -614,7 +616,7 @@ def _process_executor_events(self, session: Session = None) -> int:
# Check state of finished tasks
filter_for_tis = TI.filter_for_tis(tis_with_right_state)
- query = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model'))
+ query = session.query(TI).filter(filter_for_tis).options(selectinload("dag_model"))
# row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
# multi-schedulers
tis: Iterator[TI] = with_row_locks(
@@ -628,7 +630,6 @@ def _process_executor_events(self, session: Session = None) -> int:
buffer_key = ti.key.with_try_number(try_number)
state, info = event_buffer.pop(buffer_key)
- # TODO: should we fail RUNNING as well, as we do in Backfills?
if state == TaskInstanceState.QUEUED:
ti.external_executor_id = info
self.log.info("Setting external_id for %s to %s", ti, info)
@@ -664,8 +665,21 @@ def _process_executor_events(self, session: Session = None) -> int:
ti.pid,
)
- if ti.try_number == buffer_key.try_number and ti.state == State.QUEUED:
- Stats.incr('scheduler.tasks.killed_externally')
+ # There are two scenarios why the same TI with the same try_number is queued
+ # after executor is finished with it:
+ # 1) the TI was killed externally and it had no time to mark itself failed
+ # - in this case we should mark it as failed here.
+ # 2) the TI has been requeued after getting deferred - in this case either our executor has it
+ # or the TI is queued by another job. Either way, we should not fail it.
+
+ # All of this could also happen if the state is "running",
+ # but that is handled by the zombie detection.
+
+ ti_queued = ti.try_number == buffer_key.try_number and ti.state == TaskInstanceState.QUEUED
+ ti_requeued = ti.queued_by_job_id != self.id or self.executor.has_task(ti)
+
+ if ti_queued and not ti_requeued:
+ Stats.incr("scheduler.tasks.killed_externally")
msg = (
"Executor reports task instance %s finished (%s) although the "
"task says its %s. (Info: %s) Was the task killed externally?"
@@ -686,6 +700,7 @@ def _process_executor_events(self, session: Session = None) -> int:
full_filepath=ti.dag_model.fileloc,
simple_task_instance=SimpleTaskInstance.from_ti(ti),
msg=msg % (ti, state, ti.state, info),
+ processor_subdir=ti.dag_model.processor_subdir,
)
self.executor.send_callback(request)
else:
@@ -694,6 +709,8 @@ def _process_executor_events(self, session: Session = None) -> int:
return len(event_buffer)
def _execute(self) -> None:
+ from airflow.dag_processing.manager import DagFileProcessorAgent
+
self.log.info("Starting the scheduler")
# DAGs can be pickled for easier remote execution by some executors
@@ -705,11 +722,11 @@ def _execute(self) -> None:
# so the scheduler job and DAG parser don't access the DB at the same time.
async_mode = not self.using_sqlite
- processor_timeout_seconds: int = conf.getint('core', 'dag_file_processor_timeout')
+ processor_timeout_seconds: int = conf.getint("core", "dag_file_processor_timeout")
processor_timeout = timedelta(seconds=processor_timeout_seconds)
if not self._standalone_dag_processor:
self.processor_agent = DagFileProcessorAgent(
- dag_directory=self.subdir,
+ dag_directory=Path(self.subdir),
max_runs=self.num_times_parse_dags,
processor_timeout=processor_timeout,
dag_ids=[],
@@ -725,6 +742,8 @@ def _execute(self) -> None:
get_sink_pipe=self.processor_agent.get_callbacks_pipe
)
else:
+ from airflow.callbacks.database_callback_sink import DatabaseCallbackSink
+
self.log.debug("Using DatabaseCallbackSink as callback sink.")
self.executor.callback_sink = DatabaseCallbackSink()
@@ -750,7 +769,7 @@ def _execute(self) -> None:
self.log.info(
"Deactivating DAGs that haven't been touched since %s", execute_start_time.isoformat()
)
- models.DAG.deactivate_stale_dags(execute_start_time)
+ DAG.deactivate_stale_dags(execute_start_time)
settings.Session.remove() # type: ignore
except Exception:
@@ -768,25 +787,32 @@ def _execute(self) -> None:
self.log.exception("Exception when executing DagFileProcessorAgent.end")
self.log.info("Exited execute loop")
- def _update_dag_run_state_for_paused_dags(self):
+ @provide_session
+ def _update_dag_run_state_for_paused_dags(self, session: Session = NEW_SESSION) -> None:
try:
- paused_dag_ids = DagModel.get_all_paused_dag_ids()
- for dag_id in paused_dag_ids:
- if dag_id in self._paused_dag_without_running_dagruns:
- continue
-
- dag = SerializedDagModel.get_dag(dag_id)
+ paused_runs = (
+ session.query(DagRun)
+ .join(DagRun.dag_model)
+ .join(TaskInstance)
+ .filter(
+ DagModel.is_paused == expression.true(),
+ DagRun.state == DagRunState.RUNNING,
+ DagRun.run_type != DagRunType.BACKFILL_JOB,
+ )
+ .having(DagRun.last_scheduling_decision <= func.max(TaskInstance.updated_at))
+ .group_by(DagRun)
+ )
+ for dag_run in paused_runs:
+ dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
if dag is None:
continue
- dag_runs = DagRun.find(dag_id=dag_id, state=State.RUNNING)
- for dag_run in dag_runs:
- dag_run.dag = dag
- _, callback_to_run = dag_run.update_state(execute_callbacks=False)
- if callback_to_run:
- self._send_dag_callbacks_to_processor(dag, callback_to_run)
- self._paused_dag_without_running_dagruns.add(dag_id)
+
+ dag_run.dag = dag
+ _, callback_to_run = dag_run.update_state(execute_callbacks=False, session=session)
+ if callback_to_run:
+ self._send_dag_callbacks_to_processor(dag, callback_to_run)
except Exception as e: # should not fail the scheduler
- self.log.exception('Failed to update dag run state for paused dags due to %s', str(e))
+ self.log.exception("Failed to update dag run state for paused dags due to %s", str(e))
def _run_scheduler_loop(self) -> None:
"""
@@ -803,11 +829,10 @@ def _run_scheduler_loop(self) -> None:
.. image:: ../docs/apache-airflow/img/scheduler_loop.jpg
- :rtype: None
"""
if not self.processor_agent and not self._standalone_dag_processor:
raise ValueError("Processor agent is not started.")
- is_unit_test: bool = conf.getboolean('core', 'unit_test_mode')
+ is_unit_test: bool = conf.getboolean("core", "unit_test_mode")
timers = EventScheduler()
@@ -815,28 +840,39 @@ def _run_scheduler_loop(self) -> None:
self.adopt_or_reset_orphaned_tasks()
timers.call_regular_interval(
- conf.getfloat('scheduler', 'orphaned_tasks_check_interval', fallback=300.0),
+ conf.getfloat("scheduler", "orphaned_tasks_check_interval", fallback=300.0),
self.adopt_or_reset_orphaned_tasks,
)
timers.call_regular_interval(
- conf.getfloat('scheduler', 'trigger_timeout_check_interval', fallback=15.0),
+ conf.getfloat("scheduler", "trigger_timeout_check_interval", fallback=15.0),
self.check_trigger_timeouts,
)
timers.call_regular_interval(
- conf.getfloat('scheduler', 'pool_metrics_interval', fallback=5.0),
+ conf.getfloat("scheduler", "pool_metrics_interval", fallback=5.0),
self._emit_pool_metrics,
)
timers.call_regular_interval(
- conf.getfloat('scheduler', 'zombie_detection_interval', fallback=10.0),
+ conf.getfloat("scheduler", "zombie_detection_interval", fallback=10.0),
self._find_zombies,
)
timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags)
+ timers.call_regular_interval(
+ conf.getfloat("scheduler", "parsing_cleanup_interval"),
+ self._orphan_unreferenced_datasets,
+ )
+
+ if self._standalone_dag_processor:
+ timers.call_regular_interval(
+ conf.getfloat("scheduler", "parsing_cleanup_interval"),
+ self._cleanup_stale_dags,
+ )
+
for loop_count in itertools.count(start=1):
- with Stats.timer() as timer:
+ with Stats.timer("scheduler.scheduler_loop_duration") as timer:
if self.using_sqlite and self.processor_agent:
self.processor_agent.run_single_parsing_loop()
@@ -885,7 +921,7 @@ def _run_scheduler_loop(self) -> None:
)
break
- def _do_scheduling(self, session) -> int:
+ def _do_scheduling(self, session: Session) -> int:
"""
This function is where the main scheduling decisions take place. It:
@@ -913,7 +949,6 @@ def _do_scheduling(self, session) -> int:
See docs of _critical_section_enqueue_task_instances for more.
:return: Number of TIs enqueued in this iteration
- :rtype: int
"""
# Put a check in place to make sure we don't commit unexpectedly
with prohibit_commit(session) as guard:
@@ -926,12 +961,7 @@ def _do_scheduling(self, session) -> int:
# Bulk fetch the currently active dag runs for the dags we are
# examining, rather than making one query per DagRun
- callback_tuples = []
- for dag_run in dag_runs:
- callback_to_run = self._schedule_dag_run(dag_run, session)
- callback_tuples.append((dag_run, callback_to_run))
-
- guard.commit()
+ callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, session)
# Send the callbacks after we commit to ensure the context is up to date when it gets run
for dag_run, callback_to_run in callback_tuples:
@@ -954,7 +984,7 @@ def _do_scheduling(self, session) -> int:
num_queued_tis = 0
else:
try:
- timer = Stats.timer('scheduler.critical_section_duration')
+ timer = Stats.timer("scheduler.critical_section_duration")
timer.start()
# Find anything TIs in state SCHEDULED, try to QUEUE it (send it to the executor)
@@ -968,7 +998,7 @@ def _do_scheduling(self, session) -> int:
if is_lock_not_available_error(error=e):
self.log.debug("Critical section lock held by another Scheduler")
- Stats.incr('scheduler.critical_section_busy')
+ Stats.incr("scheduler.critical_section_busy")
session.rollback()
return 0
raise
@@ -983,10 +1013,19 @@ def _get_next_dagruns_to_examine(self, state: DagRunState, session: Session):
return DagRun.next_dagruns_to_examine(state, session)
@retry_db_transaction
- def _create_dagruns_for_dags(self, guard, session):
+ def _create_dagruns_for_dags(self, guard: CommitProhibitorGuard, session: Session) -> None:
"""Find Dag Models needing DagRuns and Create Dag Runs with retries in case of OperationalError"""
- query = DagModel.dags_needing_dagruns(session)
- self._create_dag_runs(query.all(), session)
+ query, dataset_triggered_dag_info = DagModel.dags_needing_dagruns(session)
+ all_dags_needing_dag_runs = set(query.all())
+ dataset_triggered_dags = [
+ dag for dag in all_dags_needing_dag_runs if dag.dag_id in dataset_triggered_dag_info
+ ]
+ non_dataset_dags = all_dags_needing_dag_runs.difference(dataset_triggered_dags)
+ self._create_dag_runs(non_dataset_dags, session)
+ if dataset_triggered_dags:
+ self._create_dag_runs_dataset_triggered(
+ dataset_triggered_dags, dataset_triggered_dag_info, session
+ )
# commit the session - Release the write lock on DagModel table.
guard.commit()
@@ -1052,7 +1091,107 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
# TODO[HA]: Should we do a session.flush() so we don't have to keep lots of state/object in
# memory for larger dags? or expunge_all()
- def _should_update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active_runs) -> bool:
+ def _create_dag_runs_dataset_triggered(
+ self,
+ dag_models: Collection[DagModel],
+ dataset_triggered_dag_info: dict[str, tuple[datetime, datetime]],
+ session: Session,
+ ) -> None:
+ """For DAGs that are triggered by datasets, create dag runs."""
+ # Bulk Fetch DagRuns with dag_id and execution_date same
+ # as DagModel.dag_id and DagModel.next_dagrun
+ # This list is used to verify if the DagRun already exist so that we don't attempt to create
+ # duplicate dag runs
+ exec_dates = {
+ dag_id: timezone.coerce_datetime(last_time)
+ for dag_id, (_, last_time) in dataset_triggered_dag_info.items()
+ }
+ existing_dagruns: set[tuple[str, timezone.DateTime]] = set(
+ session.query(DagRun.dag_id, DagRun.execution_date).filter(
+ tuple_in_condition((DagRun.dag_id, DagRun.execution_date), exec_dates.items())
+ )
+ )
+
+ for dag_model in dag_models:
+ dag = self.dagbag.get_dag(dag_model.dag_id, session=session)
+ if not dag:
+ self.log.error("DAG '%s' not found in serialized_dag table", dag_model.dag_id)
+ continue
+
+ if not isinstance(dag.timetable, DatasetTriggeredTimetable):
+ self.log.error(
+ "DAG '%s' was dataset-scheduled, but didn't have a DatasetTriggeredTimetable!",
+ dag_model.dag_id,
+ )
+ continue
+
+ dag_hash = self.dagbag.dags_hash.get(dag.dag_id)
+
+ # Explicitly check if the DagRun already exists. This is an edge case
+ # where a Dag Run is created but `DagModel.next_dagrun` and `DagModel.next_dagrun_create_after`
+ # are not updated.
+ # We opted to check DagRun existence instead
+ # of catching an Integrity error and rolling back the session i.e
+ # we need to set dag.next_dagrun_info if the Dag Run already exists or if we
+ # create a new one. This is so that in the next Scheduling loop we try to create new runs
+ # instead of falling in a loop of Integrity Error.
+ exec_date = exec_dates[dag.dag_id]
+ if (dag.dag_id, exec_date) not in existing_dagruns:
+
+ previous_dag_run = (
+ session.query(DagRun)
+ .filter(
+ DagRun.dag_id == dag.dag_id,
+ DagRun.execution_date < exec_date,
+ DagRun.run_type == DagRunType.DATASET_TRIGGERED,
+ )
+ .order_by(DagRun.execution_date.desc())
+ .first()
+ )
+ dataset_event_filters = [
+ DagScheduleDatasetReference.dag_id == dag.dag_id,
+ DatasetEvent.timestamp <= exec_date,
+ ]
+ if previous_dag_run:
+ dataset_event_filters.append(DatasetEvent.timestamp > previous_dag_run.execution_date)
+
+ dataset_events = (
+ session.query(DatasetEvent)
+ .join(
+ DagScheduleDatasetReference,
+ DatasetEvent.dataset_id == DagScheduleDatasetReference.dataset_id,
+ )
+ .join(DatasetEvent.source_dag_run)
+ .filter(*dataset_event_filters)
+ .all()
+ )
+
+ data_interval = dag.timetable.data_interval_for_events(exec_date, dataset_events)
+ run_id = dag.timetable.generate_run_id(
+ run_type=DagRunType.DATASET_TRIGGERED,
+ logical_date=exec_date,
+ data_interval=data_interval,
+ session=session,
+ events=dataset_events,
+ )
+
+ dag_run = dag.create_dagrun(
+ run_id=run_id,
+ run_type=DagRunType.DATASET_TRIGGERED,
+ execution_date=exec_date,
+ data_interval=data_interval,
+ state=DagRunState.QUEUED,
+ external_trigger=False,
+ session=session,
+ dag_hash=dag_hash,
+ creating_job_id=self.id,
+ )
+ dag_run.consumed_dataset_events.extend(dataset_events)
+ session.query(DatasetDagRunQueue).filter(
+ DatasetDagRunQueue.target_dag_id == dag_run.dag_id
+ ).delete()
+
+ def _should_update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active_runs: int) -> bool:
"""Check if the dag's next_dagruns_create_after should be updated."""
if total_active_runs >= dag.max_active_runs:
self.log.info(
@@ -1065,10 +1204,7 @@ def _should_update_dag_next_dagruns(self, dag, dag_model: DagModel, total_active
return False
return True
- def _start_queued_dagruns(
- self,
- session: Session,
- ) -> None:
+ def _start_queued_dagruns(self, session: Session) -> None:
"""Find DagRuns in queued state and decide moving them to running state"""
dag_runs = self._get_next_dagruns_to_examine(DagRunState.QUEUED, session)
@@ -1088,10 +1224,9 @@ def _update_state(dag: DAG, dag_run: DagRun):
# always happening immediately after the data interval.
expected_start_date = dag.get_run_data_interval(dag_run).end
schedule_delay = dag_run.start_date - expected_start_date
- Stats.timing(f'dagrun.schedule_delay.{dag.dag_id}', schedule_delay)
+ Stats.timing(f"dagrun.schedule_delay.{dag.dag_id}", schedule_delay)
for dag_run in dag_runs:
-
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
if not dag:
self.log.error("DAG '%s' not found in serialized_dag table", dag_run.dag_id)
@@ -1108,19 +1243,32 @@ def _update_state(dag: DAG, dag_run: DagRun):
else:
active_runs_of_dags[dag_run.dag_id] += 1
_update_state(dag, dag_run)
+ dag_run.notify_dagrun_state_changed()
+
+ @retry_db_transaction
+ def _schedule_all_dag_runs(self, guard, dag_runs, session):
+ """Makes scheduling decisions for all `dag_runs`"""
+ callback_tuples = []
+ for dag_run in dag_runs:
+ callback_to_run = self._schedule_dag_run(dag_run, session)
+ callback_tuples.append((dag_run, callback_to_run))
+
+ guard.commit()
+
+ return callback_tuples
def _schedule_dag_run(
self,
dag_run: DagRun,
session: Session,
- ) -> Optional[DagCallbackRequest]:
+ ) -> DagCallbackRequest | None:
"""
Make scheduling decisions about an individual dag run
:param dag_run: The DagRun to schedule
:return: Callback that needs to be executed
"""
- callback: Optional[DagCallbackRequest] = None
+ callback: DagCallbackRequest | None = None
dag = dag_run.dag = self.dagbag.get_dag(dag_run.dag_id, session=session)
@@ -1156,19 +1304,20 @@ def _schedule_dag_run(
dag_id=dag.dag_id,
run_id=dag_run.run_id,
is_failure_callback=True,
- msg='timed_out',
+ processor_subdir=dag_model.processor_subdir,
+ msg="timed_out",
)
- # Send SLA & DAG Success/Failure Callbacks to be executed
- self._send_dag_callbacks_to_processor(dag, callback_to_execute)
- # Because we send the callback here, we need to return None
- return callback
+ dag_run.notify_dagrun_state_changed()
+ return callback_to_execute
if dag_run.execution_date > timezone.utcnow() and not dag.allow_future_exec_dates:
self.log.error("Execution date is in future: %s", dag_run.execution_date)
return callback
- self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session)
+ if not self._verify_integrity_if_dag_changed(dag_run=dag_run, session=session):
+ self.log.warning("The DAG disappeared before verifying integrity: %s. Skipping.", dag_run.dag_id)
+ return callback
# TODO[HA]: Rename update_state -> schedule_dag_run, ?? something else?
schedulable_tis, callback_to_run = dag_run.update_state(session=session, execute_callbacks=False)
if dag_run.state in State.finished:
@@ -1184,30 +1333,36 @@ def _schedule_dag_run(
return callback_to_run
- @provide_session
- def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session=None):
- """Only run DagRun.verify integrity if Serialized DAG has changed since it is slow"""
+ def _verify_integrity_if_dag_changed(self, dag_run: DagRun, session: Session) -> bool:
+ """
+ Only run DagRun.verify integrity if Serialized DAG has changed since it is slow.
+
+ Return True if we determine that DAG still exists.
+ """
latest_version = SerializedDagModel.get_latest_version_hash(dag_run.dag_id, session=session)
if dag_run.dag_hash == latest_version:
self.log.debug("DAG %s not changed structure, skipping dagrun.verify_integrity", dag_run.dag_id)
- return
+ return True
dag_run.dag_hash = latest_version
# Refresh the DAG
dag_run.dag = self.dagbag.get_dag(dag_id=dag_run.dag_id, session=session)
+ if not dag_run.dag:
+ return False
# Verify integrity also takes care of session.flush
dag_run.verify_integrity(session=session)
+ return True
- def _send_dag_callbacks_to_processor(self, dag: DAG, callback: Optional[DagCallbackRequest] = None):
+ def _send_dag_callbacks_to_processor(self, dag: DAG, callback: DagCallbackRequest | None = None) -> None:
self._send_sla_callbacks_to_processor(dag)
if callback:
self.executor.send_callback(callback)
else:
self.log.debug("callback is empty")
- def _send_sla_callbacks_to_processor(self, dag: DAG):
+ def _send_sla_callbacks_to_processor(self, dag: DAG) -> None:
"""Sends SLA Callbacks to DagFileProcessor if tasks have SLAs set and check_slas=True"""
if not settings.CHECK_SLAS:
return
@@ -1216,32 +1371,42 @@ def _send_sla_callbacks_to_processor(self, dag: DAG):
self.log.debug("Skipping SLA check for %s because no tasks in DAG have SLAs", dag)
return
- request = SlaCallbackRequest(full_filepath=dag.fileloc, dag_id=dag.dag_id)
+ if not dag.timetable.periodic:
+ self.log.debug("Skipping SLA check for %s because DAG is not scheduled", dag)
+ return
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id)
+ request = SlaCallbackRequest(
+ full_filepath=dag.fileloc,
+ dag_id=dag.dag_id,
+ processor_subdir=dag_model.processor_subdir,
+ )
self.executor.send_callback(request)
@provide_session
- def _emit_pool_metrics(self, session: Session = None) -> None:
- pools = models.Pool.slots_stats(session=session)
+ def _emit_pool_metrics(self, session: Session = NEW_SESSION) -> None:
+ from airflow.models.pool import Pool
+
+ pools = Pool.slots_stats(session=session)
for pool_name, slot_stats in pools.items():
- Stats.gauge(f'pool.open_slots.{pool_name}', slot_stats["open"])
- Stats.gauge(f'pool.queued_slots.{pool_name}', slot_stats["queued"])
- Stats.gauge(f'pool.running_slots.{pool_name}', slot_stats["running"])
+ Stats.gauge(f"pool.open_slots.{pool_name}", slot_stats["open"])
+ Stats.gauge(f"pool.queued_slots.{pool_name}", slot_stats["queued"])
+ Stats.gauge(f"pool.running_slots.{pool_name}", slot_stats["running"])
@provide_session
- def heartbeat_callback(self, session: Session = None) -> None:
- Stats.incr('scheduler_heartbeat', 1, 1)
+ def heartbeat_callback(self, session: Session = NEW_SESSION) -> None:
+ Stats.incr("scheduler_heartbeat", 1, 1)
@provide_session
- def adopt_or_reset_orphaned_tasks(self, session: Session = None):
+ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
"""
Reset any TaskInstance still in QUEUED or SCHEDULED states that were
enqueued by a SchedulerJob that is no longer running.
:return: the number of TIs reset
- :rtype: int
"""
self.log.info("Resetting orphaned tasks for active dag runs")
- timeout = conf.getint('scheduler', 'scheduler_health_check_threshold')
+ timeout = conf.getint("scheduler", "scheduler_health_check_threshold")
for attempt in run_with_db_retries(logger=self.log):
with attempt:
@@ -1264,7 +1429,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
if num_failed:
self.log.info("Marked %d SchedulerJob instances as failed", num_failed)
- Stats.incr(self.__class__.__name__.lower() + '_end', num_failed)
+ Stats.incr(self.__class__.__name__.lower() + "_end", num_failed)
resettable_states = [TaskInstanceState.QUEUED, TaskInstanceState.RUNNING]
query = (
@@ -1299,11 +1464,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
for ti in set(tis_to_reset_or_adopt) - set(to_reset):
ti.queued_by_job_id = self.id
- Stats.incr('scheduler.orphaned_tasks.cleared', len(to_reset))
- Stats.incr('scheduler.orphaned_tasks.adopted', len(tis_to_reset_or_adopt) - len(to_reset))
+ Stats.incr("scheduler.orphaned_tasks.cleared", len(to_reset))
+ Stats.incr("scheduler.orphaned_tasks.adopted", len(tis_to_reset_or_adopt) - len(to_reset))
if to_reset:
- task_instance_str = '\n\t'.join(reset_tis_message)
+ task_instance_str = "\n\t".join(reset_tis_message)
self.log.info(
"Reset the following %s orphaned TaskInstances:\n\t%s",
len(to_reset),
@@ -1320,7 +1485,7 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = None):
return len(to_reset)
@provide_session
- def check_trigger_timeouts(self, session: Session = None):
+ def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
"""
Looks at all tasks that are in the "deferred" state and whose trigger
or execution timeout has passed, so they can be marked as failed.
@@ -1345,38 +1510,117 @@ def check_trigger_timeouts(self, session: Session = None):
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
- @provide_session
- def _find_zombies(self, session):
+ def _find_zombies(self) -> None:
"""
Find zombie task instances, which are tasks that haven't heartbeated for too long
- and update the current zombie list.
+ or have a no-longer-running LocalTaskJob, and create a TaskCallbackRequest
+ to be handled by the DAG processor.
"""
+ from airflow.jobs.local_task_job import LocalTaskJob
+
self.log.debug("Finding 'running' jobs without a recent heartbeat")
limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs)
- zombies = (
- session.query(TaskInstance, DagModel.fileloc)
- .join(LocalTaskJob, TaskInstance.job_id == LocalTaskJob.id)
- .join(DagModel, TaskInstance.dag_id == DagModel.dag_id)
- .filter(TaskInstance.state == State.RUNNING)
- .filter(
- or_(
- LocalTaskJob.state != State.RUNNING,
- LocalTaskJob.latest_heartbeat < limit_dttm,
+ with create_session() as session:
+ zombies: list[tuple[TI, str, str]] = (
+ session.query(TI, DM.fileloc, DM.processor_subdir)
+ .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+ .join(LocalTaskJob, TI.job_id == LocalTaskJob.id)
+ .join(DM, TI.dag_id == DM.dag_id)
+ .filter(TI.state == TaskInstanceState.RUNNING)
+ .filter(
+ or_(
+ LocalTaskJob.state != State.RUNNING,
+ LocalTaskJob.latest_heartbeat < limit_dttm,
+ )
)
+ .filter(TI.queued_by_job_id == self.id)
+ .all()
)
- .all()
- )
if zombies:
self.log.warning("Failing (%s) jobs without heartbeat after %s", len(zombies), limit_dttm)
- for ti, file_loc in zombies:
+ for ti, file_loc, processor_subdir in zombies:
+ zombie_message_details = self._generate_zombie_message_details(ti)
request = TaskCallbackRequest(
full_filepath=file_loc,
+ processor_subdir=processor_subdir,
simple_task_instance=SimpleTaskInstance.from_ti(ti),
- msg=f"Detected {ti} as zombie",
+ msg=str(zombie_message_details),
)
self.log.error("Detected zombie job: %s", request)
self.executor.send_callback(request)
- Stats.incr('zombies_killed')
+ Stats.incr("zombies_killed")
+
+ @staticmethod
+ def _generate_zombie_message_details(ti: TaskInstance):
+ zombie_message_details = {
+ "DAG Id": ti.dag_id,
+ "Task Id": ti.task_id,
+ "Run Id": ti.run_id,
+ }
+
+ if ti.map_index != -1:
+ zombie_message_details["Map Index"] = ti.map_index
+ if ti.hostname:
+ zombie_message_details["Hostname"] = ti.hostname
+ if ti.external_executor_id:
+ zombie_message_details["External Executor Id"] = ti.external_executor_id
+
+ return zombie_message_details
+
+ @provide_session
+ def _cleanup_stale_dags(self, session: Session = NEW_SESSION) -> None:
+ """
+ Find all dags that were not updated by Dag Processor recently and mark them as inactive.
+
+ In case one of DagProcessors is stopped (in case there are multiple of them
+ for different dag folders), its dags are never marked as inactive.
+ Also remove dags from SerializedDag table.
+ Executed on schedule only if [scheduler]standalone_dag_processor is True.
+ """
+ self.log.debug("Checking dags not parsed within last %s seconds.", self._dag_stale_not_seen_duration)
+ limit_lpt = timezone.utcnow() - timedelta(seconds=self._dag_stale_not_seen_duration)
+ stale_dags = (
+ session.query(DagModel).filter(DagModel.is_active, DagModel.last_parsed_time < limit_lpt).all()
+ )
+ if not stale_dags:
+ self.log.debug("No stale dags found.")
+ return
+
+ self.log.info("Found (%d) stale dags not parsed after %s.", len(stale_dags), limit_lpt)
+ for dag in stale_dags:
+ dag.is_active = False
+ SerializedDagModel.remove_dag(dag_id=dag.dag_id, session=session)
+ session.flush()
+
+ @provide_session
+ def _orphan_unreferenced_datasets(self, session: Session = NEW_SESSION) -> None:
+ """
+ Detects datasets that are no longer referenced in any DAG schedule parameters or task outlets and
+ sets the dataset is_orphaned flag to True
+ """
+ orphaned_dataset_query = (
+ session.query(DatasetModel)
+ .join(
+ DagScheduleDatasetReference,
+ isouter=True,
+ )
+ .join(
+ TaskOutletDatasetReference,
+ isouter=True,
+ )
+ # MSSQL doesn't like it when we select a column that we haven't grouped by. All other DBs let us
+ # group by id and select all columns.
+ .group_by(DatasetModel if session.get_bind().dialect.name == "mssql" else DatasetModel.id)
+ .having(
+ and_(
+ func.count(DagScheduleDatasetReference.dag_id) == 0,
+ func.count(TaskOutletDatasetReference.dag_id) == 0,
+ )
+ )
+ )
+ for dataset in orphaned_dataset_query:
+ self.log.info("Orphaning unreferenced dataset '%s'", dataset.uri)
+ dataset.is_orphaned = expression.true()
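
The "killed externally" check that the scheduler_job.py changes above add to `_process_executor_events` can be read as a small pure function. What follows is a minimal sketch, assuming a hypothetical `FakeTI` dataclass in place of the real `TaskInstance` model and plain string states in place of `TaskInstanceState`. The function name `looks_killed_externally` and its parameters are also invented for illustration and do not exist in Airflow.

```python
from dataclasses import dataclass


@dataclass
class FakeTI:
    """Hypothetical stand-in for a task instance row (not the Airflow model)."""

    try_number: int
    state: str
    queued_by_job_id: int


def looks_killed_externally(
    ti: FakeTI,
    event_try_number: int,
    scheduler_job_id: int,
    executor_has_task: bool,
) -> bool:
    """Mirror the ti_queued / ti_requeued decision from the diff.

    The executor has reported this try as finished. The TI only counts as
    killed externally if it is still queued for the same try AND nothing
    suggests it was requeued (for example after being deferred).
    """
    # Same try number as the executor event, and the TI still says "queued".
    ti_queued = ti.try_number == event_try_number and ti.state == "queued"
    # Requeued: another job queued it, or our executor is still holding it.
    ti_requeued = ti.queued_by_job_id != scheduler_job_id or executor_has_task
    return ti_queued and not ti_requeued


if __name__ == "__main__":
    ti = FakeTI(try_number=1, state="queued", queued_by_job_id=7)
    print(looks_killed_externally(ti, 1, 7, executor_has_task=False))
    print(looks_killed_externally(ti, 1, 7, executor_has_task=True))
```

Separating the two conditions keeps the deferred-then-requeued case from being counted under `scheduler.tasks.killed_externally`. As the comment in the diff notes, the analogous situation for running tasks is left to zombie detection.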
diff --git a/airflow/jobs/triggerer_job.py b/airflow/jobs/triggerer_job.py
index ac7d22a6b1da9..20fcf20b169b0 100644
--- a/airflow/jobs/triggerer_job.py
+++ b/airflow/jobs/triggerer_job.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import asyncio
import os
@@ -22,7 +23,7 @@
import threading
import time
from collections import deque
-from typing import Deque, Dict, Set, Tuple, Type
+from typing import Deque
from sqlalchemy import func
@@ -48,14 +49,14 @@ class TriggererJob(BaseJob):
- A subthread runs all the async code
"""
- __mapper_args__ = {'polymorphic_identity': 'TriggererJob'}
+ __mapper_args__ = {"polymorphic_identity": "TriggererJob"}
def __init__(self, capacity=None, *args, **kwargs):
# Call superclass
super().__init__(*args, **kwargs)
if capacity is None:
- self.capacity = conf.getint('triggerer', 'default_capacity', fallback=1000)
+ self.capacity = conf.getint("triggerer", "default_capacity", fallback=1000)
elif isinstance(capacity, int) and capacity > 0:
self.capacity = capacity
else:
@@ -158,7 +159,7 @@ def handle_events(self):
# Tell the model to wake up its tasks
Trigger.submit_event(trigger_id=trigger_id, event=event)
# Emit stat event
- Stats.incr('triggers.succeeded')
+ Stats.incr("triggers.succeeded")
def handle_failed_triggers(self):
"""
@@ -170,10 +171,10 @@ def handle_failed_triggers(self):
trigger_id, saved_exc = self.runner.failed_triggers.popleft()
Trigger.submit_failure(trigger_id=trigger_id, exc=saved_exc)
# Emit stat event
- Stats.incr('triggers.failed')
+ Stats.incr("triggers.failed")
def emit_metrics(self):
- Stats.gauge('triggers.running', len(self.runner.triggers))
+ Stats.gauge("triggers.running", len(self.runner.triggers))
class TriggerDetails(TypedDict):
@@ -195,22 +196,22 @@ class TriggerRunner(threading.Thread, LoggingMixin):
"""
# Maps trigger IDs to their running tasks and other info
- triggers: Dict[int, TriggerDetails]
+ triggers: dict[int, TriggerDetails]
# Cache for looking up triggers by classpath
- trigger_cache: Dict[str, Type[BaseTrigger]]
+ trigger_cache: dict[str, type[BaseTrigger]]
# Inbound queue of new triggers
- to_create: Deque[Tuple[int, BaseTrigger]]
+ to_create: Deque[tuple[int, BaseTrigger]]
# Inbound queue of deleted triggers
to_cancel: Deque[int]
# Outbound queue of events
- events: Deque[Tuple[int, TriggerEvent]]
+ events: Deque[tuple[int, TriggerEvent]]
# Outbound queue of failed triggers
- failed_triggers: Deque[Tuple[int, BaseException]]
+ failed_triggers: Deque[tuple[int, BaseException]]
# Should-we-stop flag
stop: bool = False
@@ -346,7 +347,7 @@ async def block_watchdog(self):
"to get more information on overrunning coroutines.",
time_elapsed,
)
- Stats.incr('triggers.blocked_main_thread')
+ Stats.incr("triggers.blocked_main_thread")
# Async trigger logic
@@ -355,10 +356,10 @@ async def run_trigger(self, trigger_id, trigger):
Wrapper which runs an actual trigger (they are async generators)
and pushes their events into our outbound event deque.
"""
- self.log.info("Trigger %s starting", self.triggers[trigger_id]['name'])
+ self.log.info("Trigger %s starting", self.triggers[trigger_id]["name"])
try:
async for event in trigger.run():
- self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]['name'], event)
+ self.log.info("Trigger %s fired: %s", self.triggers[trigger_id]["name"], event)
self.triggers[trigger_id]["events"] += 1
self.events.append((trigger_id, event))
finally:
@@ -370,7 +371,7 @@ async def run_trigger(self, trigger_id, trigger):
# Main-thread sync API
- def update_triggers(self, requested_trigger_ids: Set[int]):
+ def update_triggers(self, requested_trigger_ids: set[int]):
"""
Called from the main thread to request that we update what
triggers we're running.
@@ -413,7 +414,7 @@ def update_triggers(self, requested_trigger_ids: Set[int]):
for old_id in cancel_trigger_ids:
self.to_cancel.append(old_id)
- def get_trigger_by_classpath(self, classpath: str) -> Type[BaseTrigger]:
+ def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
"""
Gets a trigger class by its classpath ("path.to.module.classname")
diff --git a/airflow/kubernetes/k8s_model.py b/airflow/kubernetes/k8s_model.py
index 01e294672aa52..123294e0bdc20 100644
--- a/airflow/kubernetes/k8s_model.py
+++ b/airflow/kubernetes/k8s_model.py
@@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Classes for interacting with Kubernetes API."""
+from __future__ import annotations
+
from abc import ABC, abstractmethod
from functools import reduce
-from typing import List, Optional
from kubernetes.client import models as k8s
@@ -42,7 +43,7 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod:
"""
-def append_to_pod(pod: k8s.V1Pod, k8s_objects: Optional[List[K8SModel]]):
+def append_to_pod(pod: k8s.V1Pod, k8s_objects: list[K8SModel] | None):
"""
:param pod: A pod to attach a list of Kubernetes objects to
:param k8s_objects: a list of K8SModels, or None
diff --git a/airflow/kubernetes/kube_client.py b/airflow/kubernetes/kube_client.py
index 7e6ba05119787..46aff02b9d947 100644
--- a/airflow/kubernetes/kube_client.py
+++ b/airflow/kubernetes/kube_client.py
@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""Client for kubernetes communication"""
+from __future__ import annotations
+
import logging
-from typing import Optional
from airflow.configuration import conf
@@ -30,7 +31,10 @@
has_kubernetes = True
def _disable_verify_ssl() -> None:
- configuration = Configuration()
+ if hasattr(Configuration, "get_default_copy"):
+ configuration = Configuration.get_default_copy()
+ else:
+ configuration = Configuration()
configuration.verify_ssl = False
Configuration.set_default(configuration)
@@ -55,35 +59,35 @@ def _enable_tcp_keepalive() -> None:
from urllib3.connection import HTTPConnection, HTTPSConnection
- tcp_keep_idle = conf.getint('kubernetes', 'tcp_keep_idle')
- tcp_keep_intvl = conf.getint('kubernetes', 'tcp_keep_intvl')
- tcp_keep_cnt = conf.getint('kubernetes', 'tcp_keep_cnt')
+ tcp_keep_idle = conf.getint("kubernetes_executor", "tcp_keep_idle")
+ tcp_keep_intvl = conf.getint("kubernetes_executor", "tcp_keep_intvl")
+ tcp_keep_cnt = conf.getint("kubernetes_executor", "tcp_keep_cnt")
socket_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
if hasattr(socket, "TCP_KEEPIDLE"):
socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, tcp_keep_idle))
else:
- log.warning("Unable to set TCP_KEEPIDLE on this platform")
+ log.debug("Unable to set TCP_KEEPIDLE on this platform")
if hasattr(socket, "TCP_KEEPINTVL"):
socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, tcp_keep_intvl))
else:
- log.warning("Unable to set TCP_KEEPINTVL on this platform")
+ log.debug("Unable to set TCP_KEEPINTVL on this platform")
if hasattr(socket, "TCP_KEEPCNT"):
socket_options.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, tcp_keep_cnt))
else:
- log.warning("Unable to set TCP_KEEPCNT on this platform")
+ log.debug("Unable to set TCP_KEEPCNT on this platform")
HTTPSConnection.default_socket_options = HTTPSConnection.default_socket_options + socket_options
HTTPConnection.default_socket_options = HTTPConnection.default_socket_options + socket_options
def get_kube_client(
- in_cluster: bool = conf.getboolean('kubernetes', 'in_cluster'),
- cluster_context: Optional[str] = None,
- config_file: Optional[str] = None,
+ in_cluster: bool = conf.getboolean("kubernetes_executor", "in_cluster"),
+ cluster_context: str | None = None,
+ config_file: str | None = None,
) -> client.CoreV1Api:
"""
Retrieves Kubernetes client
@@ -97,19 +101,19 @@ def get_kube_client(
if not has_kubernetes:
raise _import_err
- if conf.getboolean('kubernetes', 'enable_tcp_keepalive'):
+ if conf.getboolean("kubernetes_executor", "enable_tcp_keepalive"):
_enable_tcp_keepalive()
- if not conf.getboolean('kubernetes', 'verify_ssl'):
- _disable_verify_ssl()
-
if in_cluster:
config.load_incluster_config()
else:
if cluster_context is None:
- cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None)
+ cluster_context = conf.get("kubernetes_executor", "cluster_context", fallback=None)
if config_file is None:
- config_file = conf.get('kubernetes', 'config_file', fallback=None)
+ config_file = conf.get("kubernetes_executor", "config_file", fallback=None)
config.load_kube_config(config_file=config_file, context=cluster_context)
+ if not conf.getboolean("kubernetes_executor", "verify_ssl"):
+ _disable_verify_ssl()
+
return client.CoreV1Api()
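The keepalive handling in ``_enable_tcp_keepalive`` above can be illustrated in isolation. The following is a minimal sketch, not the Airflow implementation: ``build_keepalive_options`` is a hypothetical helper, and the values 120/30/6 are illustrative stand-ins for the ``tcp_keep_idle``, ``tcp_keep_intvl`` and ``tcp_keep_cnt`` settings read from the ``kubernetes_executor`` section. It shows the same ``hasattr`` guard that skips options a platform does not expose.

```python
from __future__ import annotations

import socket


def build_keepalive_options(idle: int, interval: int, count: int) -> list[tuple[int, int, int]]:
    """Collect the TCP keepalive socket options available on this platform."""
    # SO_KEEPALIVE itself is always requested.
    options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
    # The finer-grained knobs are platform dependent, so each one is guarded.
    for name, value in (
        ("TCP_KEEPIDLE", idle),
        ("TCP_KEEPINTVL", interval),
        ("TCP_KEEPCNT", count),
    ):
        if hasattr(socket, name):
            options.append((socket.IPPROTO_TCP, getattr(socket, name), value))
    return options


# Illustrative values only; the real ones come from the Airflow configuration.
options = build_keepalive_options(idle=120, interval=30, count=6)
```

The diff then appends such a list to the ``default_socket_options`` of urllib3's ``HTTPConnection`` and ``HTTPSConnection``; the sketch stops before that step so it stays free of third-party imports.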
diff --git a/airflow/kubernetes/kube_config.py b/airflow/kubernetes/kube_config.py
index 2b550214543dc..0285f65208d01 100644
--- a/airflow/kubernetes/kube_config.py
+++ b/airflow/kubernetes/kube_config.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException
@@ -23,30 +24,30 @@
class KubeConfig:
"""Configuration for Kubernetes"""
- core_section = 'core'
- kubernetes_section = 'kubernetes'
- logging_section = 'logging'
+ core_section = "core"
+ kubernetes_section = "kubernetes_executor"
+ logging_section = "logging"
def __init__(self):
configuration_dict = conf.as_dict(display_sensitive=True)
self.core_configuration = configuration_dict[self.core_section]
self.airflow_home = AIRFLOW_HOME
- self.dags_folder = conf.get(self.core_section, 'dags_folder')
- self.parallelism = conf.getint(self.core_section, 'parallelism')
- self.pod_template_file = conf.get(self.kubernetes_section, 'pod_template_file', fallback=None)
+ self.dags_folder = conf.get(self.core_section, "dags_folder")
+ self.parallelism = conf.getint(self.core_section, "parallelism")
+ self.pod_template_file = conf.get(self.kubernetes_section, "pod_template_file", fallback=None)
- self.delete_worker_pods = conf.getboolean(self.kubernetes_section, 'delete_worker_pods')
+ self.delete_worker_pods = conf.getboolean(self.kubernetes_section, "delete_worker_pods")
self.delete_worker_pods_on_failure = conf.getboolean(
- self.kubernetes_section, 'delete_worker_pods_on_failure'
+ self.kubernetes_section, "delete_worker_pods_on_failure"
)
self.worker_pods_creation_batch_size = conf.getint(
- self.kubernetes_section, 'worker_pods_creation_batch_size'
+ self.kubernetes_section, "worker_pods_creation_batch_size"
)
- self.worker_container_repository = conf.get(self.kubernetes_section, 'worker_container_repository')
- self.worker_container_tag = conf.get(self.kubernetes_section, 'worker_container_tag')
+ self.worker_container_repository = conf.get(self.kubernetes_section, "worker_container_repository")
+ self.worker_container_tag = conf.get(self.kubernetes_section, "worker_container_tag")
if self.worker_container_repository and self.worker_container_tag:
- self.kube_image = f'{self.worker_container_repository}:{self.worker_container_tag}'
+ self.kube_image = f"{self.worker_container_repository}:{self.worker_container_tag}"
else:
self.kube_image = None
@@ -54,27 +55,27 @@ def __init__(self):
# that if your
# cluster has RBAC enabled, your scheduler may need service account permissions to
# create, watch, get, and delete pods in this namespace.
- self.kube_namespace = conf.get(self.kubernetes_section, 'namespace')
- self.multi_namespace_mode = conf.getboolean(self.kubernetes_section, 'multi_namespace_mode')
+ self.kube_namespace = conf.get(self.kubernetes_section, "namespace")
+ self.multi_namespace_mode = conf.getboolean(self.kubernetes_section, "multi_namespace_mode")
# The Kubernetes Namespace in which pods will be created by the executor. Note
# that if your
# cluster has RBAC enabled, your workers may need service account permissions to
# interact with cluster components.
- self.executor_namespace = conf.get(self.kubernetes_section, 'namespace')
+ self.executor_namespace = conf.get(self.kubernetes_section, "namespace")
- self.worker_pods_pending_timeout = conf.getint(self.kubernetes_section, 'worker_pods_pending_timeout')
+ self.worker_pods_pending_timeout = conf.getint(self.kubernetes_section, "worker_pods_pending_timeout")
self.worker_pods_pending_timeout_check_interval = conf.getint(
- self.kubernetes_section, 'worker_pods_pending_timeout_check_interval'
+ self.kubernetes_section, "worker_pods_pending_timeout_check_interval"
)
self.worker_pods_pending_timeout_batch_size = conf.getint(
- self.kubernetes_section, 'worker_pods_pending_timeout_batch_size'
+ self.kubernetes_section, "worker_pods_pending_timeout_batch_size"
)
self.worker_pods_queued_check_interval = conf.getint(
- self.kubernetes_section, 'worker_pods_queued_check_interval'
+ self.kubernetes_section, "worker_pods_queued_check_interval"
)
self.kube_client_request_args = conf.getjson(
- self.kubernetes_section, 'kube_client_request_args', fallback={}
+ self.kubernetes_section, "kube_client_request_args", fallback={}
)
if not isinstance(self.kube_client_request_args, dict):
raise AirflowConfigException(
@@ -82,13 +83,13 @@ def __init__(self):
+ type(self.kube_client_request_args).__name__
)
if self.kube_client_request_args:
- if '_request_timeout' in self.kube_client_request_args and isinstance(
- self.kube_client_request_args['_request_timeout'], list
+ if "_request_timeout" in self.kube_client_request_args and isinstance(
+ self.kube_client_request_args["_request_timeout"], list
):
- self.kube_client_request_args['_request_timeout'] = tuple(
- self.kube_client_request_args['_request_timeout']
+ self.kube_client_request_args["_request_timeout"] = tuple(
+ self.kube_client_request_args["_request_timeout"]
)
- self.delete_option_kwargs = conf.getjson(self.kubernetes_section, 'delete_option_kwargs', fallback={})
+ self.delete_option_kwargs = conf.getjson(self.kubernetes_section, "delete_option_kwargs", fallback={})
if not isinstance(self.delete_option_kwargs, dict):
raise AirflowConfigException(
f"[{self.kubernetes_section}] 'delete_option_kwargs' expected a JSON dict, got "
diff --git a/airflow/kubernetes/kubernetes_helper_functions.py b/airflow/kubernetes/kubernetes_helper_functions.py
index 1068a0521b894..1f0c809cf94bd 100644
--- a/airflow/kubernetes/kubernetes_helper_functions.py
+++ b/airflow/kubernetes/kubernetes_helper_functions.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import logging
-from typing import Dict, Optional
import pendulum
from slugify import slugify
@@ -26,22 +26,7 @@
log = logging.getLogger(__name__)
-def _strip_unsafe_kubernetes_special_chars(string: str) -> str:
- """
- Kubernetes only supports lowercase alphanumeric characters, "-" and "." in
- the pod name.
- However, there are special rules about how "-" and "." can be used so let's
- only keep
- alphanumeric chars see here for detail:
- https://kubernetes.io/docs/concepts/overview/working-with-objects/names/
-
- :param string: The requested Pod name
- :return: Pod name stripped of any unsafe characters
- """
- return slugify(string, separator='', lowercase=True)
-
-
-def create_pod_id(dag_id: str, task_id: str) -> str:
+def create_pod_id(dag_id: str | None = None, task_id: str | None = None) -> str:
"""
Generates the kubernetes safe pod_id. Note that this is
NOT the full ID that will be launched to k8s. We will add a uuid
@@ -51,27 +36,32 @@ def create_pod_id(dag_id: str, task_id: str) -> str:
:param task_id: Task ID
:return: The non-unique pod_id for this task/DAG pairing
"""
- safe_dag_id = _strip_unsafe_kubernetes_special_chars(dag_id)
- safe_task_id = _strip_unsafe_kubernetes_special_chars(task_id)
- return safe_dag_id + safe_task_id
+ name = ""
+ if dag_id:
+ name += dag_id
+ if task_id:
+ if name:
+ name += "-"
+ name += task_id
+ return slugify(name, lowercase=True)[:253].strip("-.")
-def annotations_to_key(annotations: Dict[str, str]) -> Optional[TaskInstanceKey]:
+def annotations_to_key(annotations: dict[str, str]) -> TaskInstanceKey:
"""Build a TaskInstanceKey based on pod annotations"""
log.debug("Creating task key for annotations %s", annotations)
- dag_id = annotations['dag_id']
- task_id = annotations['task_id']
- try_number = int(annotations['try_number'])
- annotation_run_id = annotations.get('run_id')
- map_index = int(annotations.get('map_index', -1))
+ dag_id = annotations["dag_id"]
+ task_id = annotations["task_id"]
+ try_number = int(annotations["try_number"])
+ annotation_run_id = annotations.get("run_id")
+ map_index = int(annotations.get("map_index", -1))
- if not annotation_run_id and 'execution_date' in annotations:
+ if not annotation_run_id and "execution_date" in annotations:
# Compat: Look up the run_id from the TI table!
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.settings import Session
- execution_date = pendulum.parse(annotations['execution_date'])
+ execution_date = pendulum.parse(annotations["execution_date"])
# Do _not_ use create-session, we don't want to expunge
session = Session()
diff --git a/airflow/kubernetes/pod.py b/airflow/kubernetes/pod.py
index a5b6cde0e335b..5b946b2e3a885 100644
--- a/airflow/kubernetes/pod.py
+++ b/airflow/kubernetes/pod.py
@@ -19,16 +19,20 @@
This module is deprecated.
Please use :mod:`kubernetes.client.models` for `V1ResourceRequirements` and `Port`.
"""
-# flake8: noqa
+from __future__ import annotations
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
+
+# flake8: noqa
+
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
from airflow.providers.cncf.kubernetes.backcompat.pod import Port, Resources # noqa: autoflake
warnings.warn(
"This module is deprecated. Please use `kubernetes.client.models` for `V1ResourceRequirements` and `Port`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
diff --git a/airflow/kubernetes/pod_generator.py b/airflow/kubernetes/pod_generator.py
index 52b45801ccabc..4382b390aa485 100644
--- a/airflow/kubernetes/pod_generator.py
+++ b/airflow/kubernetes/pod_generator.py
@@ -20,29 +20,33 @@
The advantage being that the full Kubernetes API
is supported and no serialization need be written.
"""
+from __future__ import annotations
+
import copy
import datetime
import hashlib
+import logging
import os
import re
import uuid
import warnings
from functools import reduce
-from typing import List, Optional, Union
from dateutil import parser
from kubernetes.client import models as k8s
from kubernetes.client.api_client import ApiClient
-from airflow.exceptions import AirflowConfigException
+from airflow.exceptions import AirflowConfigException, PodReconciliationError, RemovedInAirflow3Warning
from airflow.kubernetes.pod_generator_deprecated import PodDefaults, PodGenerator as PodGeneratorDeprecated
from airflow.utils import yaml
from airflow.version import version as airflow_version
+log = logging.getLogger(__name__)
+
MAX_LABEL_LEN = 63
-def make_safe_label_value(string):
+def make_safe_label_value(string: str) -> str:
"""
Valid label values must be 63 characters or less and must be empty or begin and
end with an alphanumeric character ([a-z0-9A-Z]) with dashes (-), underscores (_),
@@ -70,7 +74,7 @@ def datetime_to_label_safe_datestring(datetime_obj: datetime.datetime) -> str:
:param datetime_obj: datetime.datetime object
:return: ISO-like string representing the datetime
"""
- return datetime_obj.isoformat().replace(":", "_").replace('+', '_plus_')
+ return datetime_obj.isoformat().replace(":", "_").replace("+", "_plus_")
def label_safe_datestring_to_datetime(string: str) -> datetime.datetime:
@@ -82,7 +86,7 @@ def label_safe_datestring_to_datetime(string: str) -> datetime.datetime:
:param string: str
:return: datetime.datetime object
"""
- return parser.parse(string.replace('_plus_', '+').replace("_", ":"))
+ return parser.parse(string.replace("_plus_", "+").replace("_", ":"))
class PodGenerator:
@@ -100,8 +104,8 @@ class PodGenerator:
def __init__(
self,
- pod: Optional[k8s.V1Pod] = None,
- pod_template_file: Optional[str] = None,
+ pod: k8s.V1Pod | None = None,
+ pod_template_file: str | None = None,
extract_xcom: bool = True,
):
if not pod_template_file and not pod:
@@ -147,7 +151,7 @@ def add_xcom_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod:
return pod_cp
@staticmethod
- def from_obj(obj) -> Optional[Union[dict, k8s.V1Pod]]:
+ def from_obj(obj) -> dict | k8s.V1Pod | None:
"""Converts to pod from obj"""
if obj is None:
return None
@@ -169,19 +173,19 @@ def from_obj(obj) -> Optional[Union[dict, k8s.V1Pod]]:
return k8s_object
elif isinstance(k8s_legacy_object, dict):
warnings.warn(
- 'Using a dictionary for the executor_config is deprecated and will soon be removed.'
+ "Using a dictionary for the executor_config is deprecated and will soon be removed. "
'please use a `kubernetes.client.models.V1Pod` class with a "pod_override" key'
- ' instead. ',
- category=DeprecationWarning,
+ " instead. ",
+ category=RemovedInAirflow3Warning,
)
return PodGenerator.from_legacy_obj(obj)
else:
raise TypeError(
- 'Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig'
+ "Cannot convert a non-kubernetes.client.models.V1Pod object into a KubernetesExecutorConfig"
)
@staticmethod
- def from_legacy_obj(obj) -> Optional[k8s.V1Pod]:
+ def from_legacy_obj(obj) -> k8s.V1Pod | None:
"""Converts to pod from obj"""
if obj is None:
return None
@@ -193,18 +197,18 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]:
if not namespaced:
return None
- resources = namespaced.get('resources')
+ resources = namespaced.get("resources")
if resources is None:
requests = {
- 'cpu': namespaced.pop('request_cpu', None),
- 'memory': namespaced.pop('request_memory', None),
- 'ephemeral-storage': namespaced.get('ephemeral-storage'), # We pop this one in limits
+ "cpu": namespaced.pop("request_cpu", None),
+ "memory": namespaced.pop("request_memory", None),
+ "ephemeral-storage": namespaced.get("ephemeral-storage"), # We pop this one in limits
}
limits = {
- 'cpu': namespaced.pop('limit_cpu', None),
- 'memory': namespaced.pop('limit_memory', None),
- 'ephemeral-storage': namespaced.pop('ephemeral-storage', None),
+ "cpu": namespaced.pop("limit_cpu", None),
+ "memory": namespaced.pop("limit_memory", None),
+ "ephemeral-storage": namespaced.pop("ephemeral-storage", None),
}
all_resources = list(requests.values()) + list(limits.values())
if all(r is None for r in all_resources):
@@ -214,11 +218,11 @@ def from_legacy_obj(obj) -> Optional[k8s.V1Pod]:
requests = {k: v for k, v in requests.items() if v is not None}
limits = {k: v for k, v in limits.items() if v is not None}
resources = k8s.V1ResourceRequirements(requests=requests, limits=limits)
- namespaced['resources'] = resources
+ namespaced["resources"] = resources
return PodGeneratorDeprecated(**namespaced).gen_pod()
@staticmethod
- def reconcile_pods(base_pod: k8s.V1Pod, client_pod: Optional[k8s.V1Pod]) -> k8s.V1Pod:
+ def reconcile_pods(base_pod: k8s.V1Pod, client_pod: k8s.V1Pod | None) -> k8s.V1Pod:
"""
:param base_pod: has the base attributes which are overwritten if they exist
in the client pod and remain if they do not exist in the client_pod
@@ -253,17 +257,17 @@ def reconcile_metadata(base_meta, client_meta):
elif client_meta and base_meta:
client_meta.labels = merge_objects(base_meta.labels, client_meta.labels)
client_meta.annotations = merge_objects(base_meta.annotations, client_meta.annotations)
- extend_object_field(base_meta, client_meta, 'managed_fields')
- extend_object_field(base_meta, client_meta, 'finalizers')
- extend_object_field(base_meta, client_meta, 'owner_references')
+ extend_object_field(base_meta, client_meta, "managed_fields")
+ extend_object_field(base_meta, client_meta, "finalizers")
+ extend_object_field(base_meta, client_meta, "owner_references")
return merge_objects(base_meta, client_meta)
return None
@staticmethod
def reconcile_specs(
- base_spec: Optional[k8s.V1PodSpec], client_spec: Optional[k8s.V1PodSpec]
- ) -> Optional[k8s.V1PodSpec]:
+ base_spec: k8s.V1PodSpec | None, client_spec: k8s.V1PodSpec | None
+ ) -> k8s.V1PodSpec | None:
"""
:param base_spec: has the base attributes which are overwritten if they exist
in the client_spec and remain if they do not exist in the client_spec
@@ -278,16 +282,16 @@ def reconcile_specs(
client_spec.containers = PodGenerator.reconcile_containers(
base_spec.containers, client_spec.containers
)
- merged_spec = extend_object_field(base_spec, client_spec, 'init_containers')
- merged_spec = extend_object_field(base_spec, merged_spec, 'volumes')
+ merged_spec = extend_object_field(base_spec, client_spec, "init_containers")
+ merged_spec = extend_object_field(base_spec, merged_spec, "volumes")
return merge_objects(base_spec, merged_spec)
return None
@staticmethod
def reconcile_containers(
- base_containers: List[k8s.V1Container], client_containers: List[k8s.V1Container]
- ) -> List[k8s.V1Container]:
+ base_containers: list[k8s.V1Container], client_containers: list[k8s.V1Container]
+ ) -> list[k8s.V1Container]:
"""
:param base_containers: has the base attributes which are overwritten if they exist
in the client_containers and remain if they do not exist in the client_containers
@@ -303,11 +307,11 @@ def reconcile_containers(
client_container = client_containers[0]
base_container = base_containers[0]
- client_container = extend_object_field(base_container, client_container, 'volume_mounts')
- client_container = extend_object_field(base_container, client_container, 'env')
- client_container = extend_object_field(base_container, client_container, 'env_from')
- client_container = extend_object_field(base_container, client_container, 'ports')
- client_container = extend_object_field(base_container, client_container, 'volume_devices')
+ client_container = extend_object_field(base_container, client_container, "volume_mounts")
+ client_container = extend_object_field(base_container, client_container, "env")
+ client_container = extend_object_field(base_container, client_container, "env_from")
+ client_container = extend_object_field(base_container, client_container, "ports")
+ client_container = extend_object_field(base_container, client_container, "volume_devices")
client_container = merge_objects(base_container, client_container)
return [client_container] + PodGenerator.reconcile_containers(
@@ -321,13 +325,13 @@ def construct_pod(
pod_id: str,
try_number: int,
kube_image: str,
- date: Optional[datetime.datetime],
- args: List[str],
- pod_override_object: Optional[k8s.V1Pod],
+ date: datetime.datetime | None,
+ args: list[str],
+ pod_override_object: k8s.V1Pod | None,
base_worker_pod: k8s.V1Pod,
namespace: str,
scheduler_job_id: str,
- run_id: Optional[str] = None,
+ run_id: str | None = None,
map_index: int = -1,
) -> k8s.V1Pod:
"""
@@ -344,27 +348,27 @@ def construct_pod(
image = kube_image
annotations = {
- 'dag_id': dag_id,
- 'task_id': task_id,
- 'try_number': str(try_number),
+ "dag_id": dag_id,
+ "task_id": task_id,
+ "try_number": str(try_number),
}
labels = {
- 'airflow-worker': make_safe_label_value(scheduler_job_id),
- 'dag_id': make_safe_label_value(dag_id),
- 'task_id': make_safe_label_value(task_id),
- 'try_number': str(try_number),
- 'airflow_version': airflow_version.replace('+', '-'),
- 'kubernetes_executor': 'True',
+ "airflow-worker": make_safe_label_value(scheduler_job_id),
+ "dag_id": make_safe_label_value(dag_id),
+ "task_id": make_safe_label_value(task_id),
+ "try_number": str(try_number),
+ "airflow_version": airflow_version.replace("+", "-"),
+ "kubernetes_executor": "True",
}
if map_index >= 0:
- annotations['map_index'] = str(map_index)
- labels['map_index'] = str(map_index)
+ annotations["map_index"] = str(map_index)
+ labels["map_index"] = str(map_index)
if date:
- annotations['execution_date'] = date.isoformat()
- labels['execution_date'] = datetime_to_label_safe_datestring(date)
+ annotations["execution_date"] = date.isoformat()
+ labels["execution_date"] = datetime_to_label_safe_datestring(date)
if run_id:
- annotations['run_id'] = run_id
- labels['run_id'] = make_safe_label_value(run_id)
+ annotations["run_id"] = run_id
+ labels["run_id"] = make_safe_label_value(run_id)
dynamic_pod = k8s.V1Pod(
metadata=k8s.V1ObjectMeta(
@@ -389,7 +393,10 @@ def construct_pod(
# Pod from the pod_template_File -> Pod from executor_config arg -> Pod from the K8s executor
pod_list = [base_worker_pod, pod_override_object, dynamic_pod]
- return reduce(PodGenerator.reconcile_pods, pod_list)
+ try:
+ return reduce(PodGenerator.reconcile_pods, pod_list)
+ except Exception as e:
+ raise PodReconciliationError from e
@staticmethod
def serialize_pod(pod: k8s.V1Pod) -> dict:
@@ -408,24 +415,25 @@ def deserialize_model_file(path: str) -> k8s.V1Pod:
"""
:param path: Path to the file
:return: a kubernetes.client.models.V1Pod
-
- Unfortunately we need access to the private method
- ``_ApiClient__deserialize_model`` from the kubernetes client.
- This issue is tracked here; https://github.com/kubernetes-client/python/issues/977.
"""
if os.path.exists(path):
with open(path) as stream:
pod = yaml.safe_load(stream)
else:
- pod = yaml.safe_load(path)
+ pod = None
+ log.warning("Model file %s does not exist", path)
return PodGenerator.deserialize_model_dict(pod)
@staticmethod
- def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod:
+ def deserialize_model_dict(pod_dict: dict | None) -> k8s.V1Pod:
"""
Deserializes python dictionary to k8s.V1Pod
+ Unfortunately we need access to the private method
+ ``_ApiClient__deserialize_model`` from the kubernetes client.
+ This issue is tracked here: https://github.com/kubernetes-client/python/issues/977.
+
:param pod_dict: Serialized dict of k8s.V1Pod object
:return: De-serialized k8s.V1Pod
"""
@@ -433,7 +441,7 @@ def deserialize_model_dict(pod_dict: dict) -> k8s.V1Pod:
return api_client._ApiClient__deserialize_model(pod_dict, k8s.V1Pod)
@staticmethod
- def make_unique_pod_id(pod_id: str) -> Optional[str]:
+ def make_unique_pod_id(pod_id: str) -> str | None:
r"""
Kubernetes pod names must consist of one or more lowercase
rfc1035/rfc1123 labels separated by '.' with a maximum length of 253
@@ -456,7 +464,7 @@ def make_unique_pod_id(pod_id: str) -> Optional[str]:
# Get prefix length after subtracting the uuid length. Clean up '.' and '-' from
# end of podID ('.' can't be followed by '-').
label_prefix_length = MAX_LABEL_LEN - len(safe_uuid) - 1 # -1 for separator
- trimmed_pod_id = pod_id[:label_prefix_length].rstrip('-.')
+ trimmed_pod_id = pod_id[:label_prefix_length].rstrip("-.")
# previously used a '.' as the separator, but this could create errors in some situations
return f"{trimmed_pod_id}-{safe_uuid}"
diff --git a/airflow/kubernetes/pod_generator_deprecated.py b/airflow/kubernetes/pod_generator_deprecated.py
index fcbb2bb402d6d..f08a6c45d2231 100644
--- a/airflow/kubernetes/pod_generator_deprecated.py
+++ b/airflow/kubernetes/pod_generator_deprecated.py
@@ -20,11 +20,12 @@
The advantage being that the full Kubernetes API
is supported and no serialization need be written.
"""
+from __future__ import annotations
+
import copy
import hashlib
import re
import uuid
-from typing import Dict, List, Optional, Union
from kubernetes.client import models as k8s
@@ -36,15 +37,15 @@
class PodDefaults:
"""Static defaults for Pods"""
- XCOM_MOUNT_PATH = '/airflow/xcom'
- SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
+ XCOM_MOUNT_PATH = "/airflow/xcom"
+ SIDECAR_CONTAINER_NAME = "airflow-xcom-sidecar"
XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 30; done;'
- VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH)
- VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource())
+ VOLUME_MOUNT = k8s.V1VolumeMount(name="xcom", mount_path=XCOM_MOUNT_PATH)
+ VOLUME = k8s.V1Volume(name="xcom", empty_dir=k8s.V1EmptyDirVolumeSource())
SIDECAR_CONTAINER = k8s.V1Container(
name=SIDECAR_CONTAINER_NAME,
- command=['sh', '-c', XCOM_CMD],
- image='alpine',
+ command=["sh", "-c", XCOM_CMD],
+ image="alpine",
volume_mounts=[VOLUME_MOUNT],
resources=k8s.V1ResourceRequirements(
requests={
@@ -117,38 +118,38 @@ class PodGenerator:
def __init__(
self,
- image: Optional[str] = None,
- name: Optional[str] = None,
- namespace: Optional[str] = None,
- volume_mounts: Optional[List[Union[k8s.V1VolumeMount, dict]]] = None,
- envs: Optional[Dict[str, str]] = None,
- cmds: Optional[List[str]] = None,
- args: Optional[List[str]] = None,
- labels: Optional[Dict[str, str]] = None,
- node_selectors: Optional[Dict[str, str]] = None,
- ports: Optional[List[Union[k8s.V1ContainerPort, dict]]] = None,
- volumes: Optional[List[Union[k8s.V1Volume, dict]]] = None,
- image_pull_policy: Optional[str] = None,
- restart_policy: Optional[str] = None,
- image_pull_secrets: Optional[str] = None,
- init_containers: Optional[List[k8s.V1Container]] = None,
- service_account_name: Optional[str] = None,
- resources: Optional[Union[k8s.V1ResourceRequirements, dict]] = None,
- annotations: Optional[Dict[str, str]] = None,
- affinity: Optional[dict] = None,
+ image: str | None = None,
+ name: str | None = None,
+ namespace: str | None = None,
+ volume_mounts: list[k8s.V1VolumeMount | dict] | None = None,
+ envs: dict[str, str] | None = None,
+ cmds: list[str] | None = None,
+ args: list[str] | None = None,
+ labels: dict[str, str] | None = None,
+ node_selectors: dict[str, str] | None = None,
+ ports: list[k8s.V1ContainerPort | dict] | None = None,
+ volumes: list[k8s.V1Volume | dict] | None = None,
+ image_pull_policy: str | None = None,
+ restart_policy: str | None = None,
+ image_pull_secrets: str | None = None,
+ init_containers: list[k8s.V1Container] | None = None,
+ service_account_name: str | None = None,
+ resources: k8s.V1ResourceRequirements | dict | None = None,
+ annotations: dict[str, str] | None = None,
+ affinity: dict | None = None,
hostnetwork: bool = False,
- tolerations: Optional[list] = None,
- security_context: Optional[Union[k8s.V1PodSecurityContext, dict]] = None,
- configmaps: Optional[List[str]] = None,
- dnspolicy: Optional[str] = None,
- schedulername: Optional[str] = None,
+ tolerations: list | None = None,
+ security_context: k8s.V1PodSecurityContext | dict | None = None,
+ configmaps: list[str] | None = None,
+ dnspolicy: str | None = None,
+ schedulername: str | None = None,
extract_xcom: bool = False,
- priority_class_name: Optional[str] = None,
+ priority_class_name: str | None = None,
):
self.pod = k8s.V1Pod()
- self.pod.api_version = 'v1'
- self.pod.kind = 'Pod'
+ self.pod.api_version = "v1"
+ self.pod.kind = "Pod"
# Pod Metadata
self.metadata = k8s.V1ObjectMeta()
@@ -158,7 +159,7 @@ def __init__(
self.metadata.annotations = annotations
# Pod Container
- self.container = k8s.V1Container(name='base')
+ self.container = k8s.V1Container(name="base")
self.container.image = image
self.container.env = []
@@ -204,7 +205,7 @@ def __init__(
self.spec.image_pull_secrets = []
if image_pull_secrets:
- for image_pull_secret in image_pull_secrets.split(','):
+ for image_pull_secret in image_pull_secrets.split(","):
self.spec.image_pull_secrets.append(k8s.V1LocalObjectReference(name=image_pull_secret))
# Attach sidecar
@@ -240,7 +241,7 @@ def add_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod:
return pod_cp
@staticmethod
- def from_obj(obj) -> Optional[k8s.V1Pod]:
+ def from_obj(obj) -> k8s.V1Pod | None:
"""Converts to pod from obj"""
if obj is None:
return None
@@ -250,8 +251,8 @@ def from_obj(obj) -> Optional[k8s.V1Pod]:
if not isinstance(obj, dict):
raise TypeError(
- 'Cannot convert a non-dictionary or non-PodGenerator '
- 'object into a KubernetesExecutorConfig'
+ "Cannot convert a non-dictionary or non-PodGenerator "
+ "object into a KubernetesExecutorConfig"
)
# We do not want to extract constant here from ExecutorLoader because it is just
@@ -261,25 +262,25 @@ def from_obj(obj) -> Optional[k8s.V1Pod]:
if not namespaced:
return None
- resources = namespaced.get('resources')
+ resources = namespaced.get("resources")
if resources is None:
requests = {
- 'cpu': namespaced.get('request_cpu'),
- 'memory': namespaced.get('request_memory'),
- 'ephemeral-storage': namespaced.get('ephemeral-storage'),
+ "cpu": namespaced.get("request_cpu"),
+ "memory": namespaced.get("request_memory"),
+ "ephemeral-storage": namespaced.get("ephemeral-storage"),
}
limits = {
- 'cpu': namespaced.get('limit_cpu'),
- 'memory': namespaced.get('limit_memory'),
- 'ephemeral-storage': namespaced.get('ephemeral-storage'),
+ "cpu": namespaced.get("limit_cpu"),
+ "memory": namespaced.get("limit_memory"),
+ "ephemeral-storage": namespaced.get("ephemeral-storage"),
}
all_resources = list(requests.values()) + list(limits.values())
if all(r is None for r in all_resources):
resources = None
else:
resources = k8s.V1ResourceRequirements(requests=requests, limits=limits)
- namespaced['resources'] = resources
+ namespaced["resources"] = resources
return PodGenerator(**namespaced).gen_pod()
@staticmethod
diff --git a/airflow/kubernetes/pod_launcher.py b/airflow/kubernetes/pod_launcher.py
index 0b9cbbe45a481..bd52f49653313 100644
--- a/airflow/kubernetes/pod_launcher.py
+++ b/airflow/kubernetes/pod_launcher.py
@@ -19,4 +19,6 @@
This module is deprecated.
Please use :mod:`kubernetes.client.models` for V1ResourceRequirements and Port.
"""
+from __future__ import annotations
+
from airflow.kubernetes.pod_launcher_deprecated import PodLauncher, PodStatus # noqa: autoflake
diff --git a/airflow/kubernetes/pod_launcher_deprecated.py b/airflow/kubernetes/pod_launcher_deprecated.py
index 97845dad51d5a..13476c43760be 100644
--- a/airflow/kubernetes/pod_launcher_deprecated.py
+++ b/airflow/kubernetes/pod_launcher_deprecated.py
@@ -15,12 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""Launches PODs"""
+from __future__ import annotations
+
import json
import math
import time
import warnings
from datetime import datetime as dt
-from typing import Optional, Tuple
import pendulum
import tenacity
@@ -30,7 +31,7 @@
from kubernetes.stream import stream as kubernetes_stream
from requests.exceptions import HTTPError
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.pod_generator import PodDefaults
from airflow.settings import pod_mutation_hook
@@ -46,7 +47,7 @@
https://pypi.org/project/apache-airflow-providers-cncf-kubernetes/
""",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
@@ -54,10 +55,10 @@
class PodStatus:
"""Status of the PODs"""
- PENDING = 'pending'
- RUNNING = 'running'
- FAILED = 'failed'
- SUCCEEDED = 'succeeded'
+ PENDING = "pending"
+ RUNNING = "running"
+ FAILED = "failed"
+ SUCCEEDED = "succeeded"
class PodLauncher(LoggingMixin):
@@ -69,7 +70,7 @@ def __init__(
self,
kube_client: client.CoreV1Api = None,
in_cluster: bool = True,
- cluster_context: Optional[str] = None,
+ cluster_context: str | None = None,
extract_xcom: bool = False,
):
"""
@@ -94,14 +95,14 @@ def run_pod_async(self, pod: V1Pod, **kwargs):
sanitized_pod = self._client.api_client.sanitize_for_serialization(pod)
json_pod = json.dumps(sanitized_pod, indent=2)
- self.log.debug('Pod Creation Request: \n%s', json_pod)
+ self.log.debug("Pod Creation Request: \n%s", json_pod)
try:
resp = self._client.create_namespaced_pod(
body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs
)
- self.log.debug('Pod Creation Response: %s', resp)
+ self.log.debug("Pod Creation Response: %s", resp)
except Exception as e:
- self.log.exception('Exception when attempting to create Namespaced Pod: %s', json_pod)
+ self.log.exception("Exception when attempting to create Namespaced Pod: %s", json_pod)
raise e
return resp
@@ -134,13 +135,12 @@ def start_pod(self, pod: V1Pod, startup_timeout: int = 120):
raise AirflowException("Pod took too long to start")
time.sleep(1)
- def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]:
+ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> tuple[State, str | None]:
"""
Monitors a pod and returns the final state
:param pod: pod spec that will be monitored
:param get_logs: whether to read the logs locally
- :return: Tuple[State, Optional[str]]
"""
if get_logs:
read_logs_since_sec = None
@@ -148,7 +148,7 @@ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]
while True:
logs = self.read_pod_logs(pod, timestamps=True, since_seconds=read_logs_since_sec)
for line in logs:
- timestamp, message = self.parse_log_line(line.decode('utf-8'))
+ timestamp, message = self.parse_log_line(line.decode("utf-8"))
if timestamp:
last_log_time = pendulum.parse(timestamp)
self.log.info(message)
@@ -157,7 +157,7 @@ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]
if not self.base_container_is_running(pod):
break
- self.log.warning('Pod %s log read interrupted', pod.metadata.name)
+ self.log.warning("Pod %s log read interrupted", pod.metadata.name)
if last_log_time:
delta = pendulum.now() - last_log_time
# Prefer logs duplication rather than loss
@@ -165,25 +165,24 @@ def monitor_pod(self, pod: V1Pod, get_logs: bool) -> Tuple[State, Optional[str]]
result = None
if self.extract_xcom:
while self.base_container_is_running(pod):
- self.log.info('Container %s has state %s', pod.metadata.name, State.RUNNING)
+ self.log.info("Container %s has state %s", pod.metadata.name, State.RUNNING)
time.sleep(2)
result = self._extract_xcom(pod)
self.log.info(result)
result = json.loads(result)
while self.pod_is_running(pod):
- self.log.info('Pod %s has state %s', pod.metadata.name, State.RUNNING)
+ self.log.info("Pod %s has state %s", pod.metadata.name, State.RUNNING)
time.sleep(2)
return self._task_status(self.read_pod(pod)), result
- def parse_log_line(self, line: str) -> Tuple[Optional[str], str]:
+ def parse_log_line(self, line: str) -> tuple[str | None, str]:
"""
Parse K8s log line and returns the final state
:param line: k8s log line
:return: timestamp and log message
- :rtype: Tuple[str, str]
"""
- split_at = line.find(' ')
+ split_at = line.find(" ")
if split_at == -1:
self.log.error(
"Error parsing timestamp (no timestamp in message: %r). "
@@ -196,7 +195,7 @@ def parse_log_line(self, line: str) -> Tuple[Optional[str], str]:
return timestamp, message
def _task_status(self, event):
- self.log.info('Event: %s had an event of type %s', event.metadata.name, event.status.phase)
+ self.log.info("Event: %s had an event of type %s", event.metadata.name, event.status.phase)
status = self.process_status(event.metadata.name, event.status.phase)
return status
@@ -213,7 +212,7 @@ def pod_is_running(self, pod: V1Pod):
def base_container_is_running(self, pod: V1Pod):
"""Tests if base container is running"""
event = self.read_pod(pod)
- status = next(iter(filter(lambda s: s.name == 'base', event.status.container_statuses)), None)
+ status = next((s for s in event.status.container_statuses if s.name == "base"), None)
if not status:
return False
return status.state.running is not None
@@ -222,30 +221,30 @@ def base_container_is_running(self, pod: V1Pod):
def read_pod_logs(
self,
pod: V1Pod,
- tail_lines: Optional[int] = None,
+ tail_lines: int | None = None,
timestamps: bool = False,
- since_seconds: Optional[int] = None,
+ since_seconds: int | None = None,
):
"""Reads log from the POD"""
additional_kwargs = {}
if since_seconds:
- additional_kwargs['since_seconds'] = since_seconds
+ additional_kwargs["since_seconds"] = since_seconds
if tail_lines:
- additional_kwargs['tail_lines'] = tail_lines
+ additional_kwargs["tail_lines"] = tail_lines
try:
return self._client.read_namespaced_pod_log(
name=pod.metadata.name,
namespace=pod.metadata.namespace,
- container='base',
+ container="base",
follow=True,
timestamps=timestamps,
_preload_content=False,
**additional_kwargs,
)
except HTTPError as e:
- raise AirflowException(f'There was an error reading the kubernetes API: {e}')
+ raise AirflowException(f"There was an error reading the kubernetes API: {e}")
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod_events(self, pod):
@@ -255,7 +254,7 @@ def read_pod_events(self, pod):
namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}"
)
except HTTPError as e:
- raise AirflowException(f'There was an error reading the kubernetes API: {e}')
+ raise AirflowException(f"There was an error reading the kubernetes API: {e}")
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod(self, pod: V1Pod):
@@ -263,7 +262,7 @@ def read_pod(self, pod: V1Pod):
try:
return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace)
except HTTPError as e:
- raise AirflowException(f'There was an error reading the kubernetes API: {e}')
+ raise AirflowException(f"There was an error reading the kubernetes API: {e}")
def _extract_xcom(self, pod: V1Pod):
resp = kubernetes_stream(
@@ -271,7 +270,7 @@ def _extract_xcom(self, pod: V1Pod):
pod.metadata.name,
pod.metadata.namespace,
container=PodDefaults.SIDECAR_CONTAINER_NAME,
- command=['/bin/sh'],
+ command=["/bin/sh"],
stdin=True,
stdout=True,
stderr=True,
@@ -279,18 +278,18 @@ def _extract_xcom(self, pod: V1Pod):
_preload_content=False,
)
try:
- result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json')
- self._exec_pod_command(resp, 'kill -s SIGINT 1')
+ result = self._exec_pod_command(resp, f"cat {PodDefaults.XCOM_MOUNT_PATH}/return.json")
+ self._exec_pod_command(resp, "kill -s SIGINT 1")
finally:
resp.close()
if result is None:
- raise AirflowException(f'Failed to extract xcom from pod: {pod.metadata.name}')
+ raise AirflowException(f"Failed to extract xcom from pod: {pod.metadata.name}")
return result
def _exec_pod_command(self, resp, command):
if resp.is_open():
- self.log.info('Running command... %s\n', command)
- resp.write_stdin(command + '\n')
+ self.log.info("Running command... %s\n", command)
+ resp.write_stdin(command + "\n")
while resp.is_open():
resp.update(timeout=1)
if resp.peek_stdout():
@@ -306,13 +305,13 @@ def process_status(self, job_id, status):
if status == PodStatus.PENDING:
return State.QUEUED
elif status == PodStatus.FAILED:
- self.log.error('Event with job id %s Failed', job_id)
+ self.log.error("Event with job id %s Failed", job_id)
return State.FAILED
elif status == PodStatus.SUCCEEDED:
- self.log.info('Event with job id %s Succeeded', job_id)
+ self.log.info("Event with job id %s Succeeded", job_id)
return State.SUCCESS
elif status == PodStatus.RUNNING:
return State.RUNNING
else:
- self.log.error('Event: Invalid state %s on job %s', status, job_id)
+ self.log.error("Event: Invalid state %s on job %s", status, job_id)
return State.FAILED
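Review note: the `base_container_is_running` hunk in this file swaps `next(iter(filter(lambda ...)), None)` for a generator expression. The sketch below suggests the two forms select the same element. It uses `SimpleNamespace` objects as hypothetical stand-ins for the Kubernetes container status objects, so it illustrates the idiom only, not the client API.

```python
from types import SimpleNamespace

# Hypothetical stand-ins for entries of event.status.container_statuses.
statuses = [
    SimpleNamespace(name="airflow-xcom-sidecar"),
    SimpleNamespace(name="base"),
]

# Form removed by the hunk.
old_style = next(iter(filter(lambda s: s.name == "base", statuses)), None)

# Form added by the hunk.
new_style = next((s for s in statuses if s.name == "base"), None)

# With no matching container, the default is returned instead of raising StopIteration.
missing = next((s for s in statuses if s.name == "absent"), None)
```

Both forms are lazy and stop at the first match, so the change reads as a readability improvement rather than a behavioural one.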
diff --git a/airflow/kubernetes/pod_runtime_info_env.py b/airflow/kubernetes/pod_runtime_info_env.py
index a51f3b96fc39a..32e178263b126 100644
--- a/airflow/kubernetes/pod_runtime_info_env.py
+++ b/airflow/kubernetes/pod_runtime_info_env.py
@@ -16,14 +16,18 @@
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use :mod:`kubernetes.client.models.V1EnvVar`."""
+from __future__ import annotations
+
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
+
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
from airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env import PodRuntimeInfoEnv # noqa
warnings.warn(
"This module is deprecated. Please use `kubernetes.client.models.V1EnvVar`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
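Review note: several hunks replace `DeprecationWarning` with `RemovedInAirflow3Warning`. The sketch below assumes the new class subclasses `DeprecationWarning`, which this part of the diff does not show. `RemovedInExampleWarning` and `legacy_call` are hypothetical names used only for illustration.

```python
import warnings


class RemovedInExampleWarning(DeprecationWarning):
    """Hypothetical stand-in for RemovedInAirflow3Warning."""


def legacy_call() -> None:
    # stacklevel=2 attributes the warning to the caller, as in the patched modules.
    warnings.warn("legacy_call is deprecated", RemovedInExampleWarning, stacklevel=2)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_call()

# Code that checks for DeprecationWarning still matches the subclass.
is_deprecation = issubclass(caught[0].category, DeprecationWarning)
```

Under that assumption, existing filters written against `DeprecationWarning` keep working, while callers can also target the narrower class.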
diff --git a/airflow/kubernetes/secret.py b/airflow/kubernetes/secret.py
index afb30916ff357..145dcd5b2cb5f 100644
--- a/airflow/kubernetes/secret.py
+++ b/airflow/kubernetes/secret.py
@@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Classes for interacting with Kubernetes API"""
+from __future__ import annotations
+
import copy
import uuid
-from typing import Tuple
from kubernetes.client import models as k8s
@@ -45,19 +46,19 @@ def __init__(self, deploy_type, deploy_target, secret, key=None, items=None):
secret keys to paths
https://kubernetes.io/docs/concepts/configuration/secret/#projection-of-secret-keys-to-specific-paths
"""
- if deploy_type not in ('env', 'volume'):
+ if deploy_type not in ("env", "volume"):
raise AirflowConfigException("deploy_type must be env or volume")
self.deploy_type = deploy_type
self.deploy_target = deploy_target
self.items = items or []
- if deploy_target is not None and deploy_type == 'env':
+ if deploy_target is not None and deploy_type == "env":
# if deploying to env, capitalize the deploy target
self.deploy_target = deploy_target.upper()
if key is not None and deploy_target is None:
- raise AirflowConfigException('If `key` is set, `deploy_target` should not be None')
+ raise AirflowConfigException("If `key` is set, `deploy_target` should not be None")
self.secret = secret
self.key = key
@@ -75,9 +76,9 @@ def to_env_from_secret(self) -> k8s.V1EnvFromSource:
"""Reads from environment to secret"""
return k8s.V1EnvFromSource(secret_ref=k8s.V1SecretEnvSource(name=self.secret))
- def to_volume_secret(self) -> Tuple[k8s.V1Volume, k8s.V1VolumeMount]:
+ def to_volume_secret(self) -> tuple[k8s.V1Volume, k8s.V1VolumeMount]:
"""Converts to volume secret"""
- vol_id = f'secretvol{uuid.uuid4()}'
+ vol_id = f"secretvol{uuid.uuid4()}"
volume = k8s.V1Volume(name=vol_id, secret=k8s.V1SecretVolumeSource(secret_name=self.secret))
if self.items:
volume.secret.items = self.items
@@ -87,7 +88,7 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod:
"""Attaches to pod"""
cp_pod = copy.deepcopy(pod)
- if self.deploy_type == 'volume':
+ if self.deploy_type == "volume":
volume, volume_mount = self.to_volume_secret()
if cp_pod.spec.volumes is None:
cp_pod.spec.volumes = []
@@ -96,13 +97,13 @@ def attach_to_pod(self, pod: k8s.V1Pod) -> k8s.V1Pod:
cp_pod.spec.containers[0].volume_mounts = []
cp_pod.spec.containers[0].volume_mounts.append(volume_mount)
- if self.deploy_type == 'env' and self.key is not None:
+ if self.deploy_type == "env" and self.key is not None:
env = self.to_env_secret()
if cp_pod.spec.containers[0].env is None:
cp_pod.spec.containers[0].env = []
cp_pod.spec.containers[0].env.append(env)
- if self.deploy_type == 'env' and self.key is None:
+ if self.deploy_type == "env" and self.key is None:
env_from = self.to_env_from_secret()
if cp_pod.spec.containers[0].env_from is None:
cp_pod.spec.containers[0].env_from = []
@@ -119,4 +120,4 @@ def __eq__(self, other):
)
def __repr__(self):
- return f'Secret({self.deploy_type}, {self.deploy_target}, {self.secret}, {self.key})'
+ return f"Secret({self.deploy_type}, {self.deploy_target}, {self.secret}, {self.key})"
diff --git a/airflow/kubernetes/volume.py b/airflow/kubernetes/volume.py
index 81b4fda3a15da..ecb39e457fd4a 100644
--- a/airflow/kubernetes/volume.py
+++ b/airflow/kubernetes/volume.py
@@ -16,14 +16,18 @@
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use :mod:`kubernetes.client.models.V1Volume`."""
+from __future__ import annotations
+
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
+
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
from airflow.providers.cncf.kubernetes.backcompat.volume import Volume # noqa: autoflake
warnings.warn(
"This module is deprecated. Please use `kubernetes.client.models.V1Volume`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
diff --git a/airflow/kubernetes/volume_mount.py b/airflow/kubernetes/volume_mount.py
index f558425881752..e65351d85f5af 100644
--- a/airflow/kubernetes/volume_mount.py
+++ b/airflow/kubernetes/volume_mount.py
@@ -16,14 +16,18 @@
# specific language governing permissions and limitations
# under the License.
"""This module is deprecated. Please use :mod:`kubernetes.client.models.V1VolumeMount`."""
+from __future__ import annotations
+
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
+
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
from airflow.providers.cncf.kubernetes.backcompat.volume_mount import VolumeMount # noqa: autoflake
warnings.warn(
"This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py
index 3d8c696487915..173956b74c289 100644
--- a/airflow/lineage/__init__.py
+++ b/airflow/lineage/__init__.py
@@ -15,21 +15,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Provides lineage support functions"""
-import json
+"""Provides lineage support functions."""
+from __future__ import annotations
+
+import itertools
import logging
from functools import wraps
-from typing import Any, Callable, Dict, Optional, TypeVar, cast
-
-import attr
-import jinja2
-from cattr import structure, unstructure
+from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from airflow.configuration import conf
from airflow.lineage.backend import LineageBackend
-from airflow.utils.module_loading import import_string
-ENV = jinja2.Environment()
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
PIPELINE_OUTLETS = "pipeline_outlets"
PIPELINE_INLETS = "pipeline_inlets"
@@ -38,17 +37,8 @@
log = logging.getLogger(__name__)
-@attr.s(auto_attribs=True)
-class Metadata:
- """Class for serialized entities."""
-
- type_name: str = attr.ib()
- source: str = attr.ib()
- data: Dict = attr.ib()
-
-
-def get_backend() -> Optional[LineageBackend]:
- """Gets the lineage backend if defined in the configs"""
+def get_backend() -> LineageBackend | None:
+ """Gets the lineage backend if defined in the configs."""
clazz = conf.getimport("lineage", "backend", fallback=None)
if clazz:
@@ -63,33 +53,8 @@ def get_backend() -> Optional[LineageBackend]:
return None
-def _get_instance(meta: Metadata):
- """Instantiate an object from Metadata"""
- cls = import_string(meta.type_name)
- return structure(meta.data, cls)
-
-
-def _render_object(obj: Any, context) -> Any:
- """Renders a attr annotated object. Will set non serializable attributes to none"""
- return structure(
- json.loads(
- ENV.from_string(json.dumps(unstructure(obj), default=lambda o: None))
- .render(**context)
- .encode('utf-8')
- ),
- type(obj),
- )
-
-
-def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
- """Create Metadata from attr annotated object"""
- if not attr.has(obj):
- return None
-
- type_name = obj.__module__ + '.' + obj.__class__.__name__
- data = unstructure(obj)
-
- return Metadata(type_name, source, data)
+def _render_object(obj: Any, context: Context) -> dict:
+ return context["ti"].task.render_template(obj, context)
T = TypeVar("T", bound=Callable)
@@ -97,6 +62,8 @@ def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
def apply_lineage(func: T) -> T:
"""
+ Conditionally send lineage to the backend.
+
Saves the lineage to XCom and if configured to do so sends it
to the backend.
"""
@@ -104,20 +71,22 @@ def apply_lineage(func: T) -> T:
@wraps(func)
def wrapper(self, context, *args, **kwargs):
+
self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)
+
ret_val = func(self, context, *args, **kwargs)
- outlets = [unstructure(_to_dataset(x, f"{self.dag_id}.{self.task_id}")) for x in self.outlets]
- inlets = [unstructure(_to_dataset(x, None)) for x in self.inlets]
+ outlets = list(self.outlets)
+ inlets = list(self.inlets)
- if self.outlets:
+ if outlets:
self.xcom_push(
- context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context['ti'].execution_date
+ context, key=PIPELINE_OUTLETS, value=outlets, execution_date=context["ti"].execution_date
)
- if self.inlets:
+ if inlets:
self.xcom_push(
- context, key=PIPELINE_INLETS, value=inlets, execution_date=context['ti'].execution_date
+ context, key=PIPELINE_INLETS, value=inlets, execution_date=context["ti"].execution_date
)
if _backend:
@@ -130,7 +99,9 @@ def wrapper(self, context, *args, **kwargs):
def prepare_lineage(func: T) -> T:
"""
- Prepares the lineage inlets and outlets. Inlets can be:
+ Prepares the lineage inlets and outlets.
+
+ Inlets can be:
* "auto" -> picks up any outlets from direct upstream tasks that have outlets defined, as such that
if A -> B -> C and B does not have outlets but A does, these are provided as inlets.
@@ -145,49 +116,44 @@ def wrapper(self, context, *args, **kwargs):
self.log.debug("Preparing lineage inlets and outlets")
- if isinstance(self._inlets, (str, AbstractOperator)) or attr.has(self._inlets):
- self._inlets = [
- self._inlets,
- ]
+ if isinstance(self.inlets, (str, AbstractOperator)):
+ self.inlets = [self.inlets]
- if self._inlets and isinstance(self._inlets, list):
+ if self.inlets and isinstance(self.inlets, list):
# get task_ids that are specified as parameter and make sure they are upstream
task_ids = (
- {o for o in self._inlets if isinstance(o, str)}
- .union(op.task_id for op in self._inlets if isinstance(op, AbstractOperator))
+ {o for o in self.inlets if isinstance(o, str)}
+ .union(op.task_id for op in self.inlets if isinstance(op, AbstractOperator))
.intersection(self.get_flat_relative_ids(upstream=True))
)
# pick up unique direct upstream task_ids if AUTO is specified
- if AUTO.upper() in self._inlets or AUTO.lower() in self._inlets:
+ if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets:
task_ids = task_ids.union(task_ids.symmetric_difference(self.upstream_task_ids))
+ # Remove auto and task_ids
+ self.inlets = [i for i in self.inlets if not isinstance(i, str)]
_inlets = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS)
# re-instantiate the obtained inlets
- _inlets = [
- _get_instance(structure(item, Metadata)) for sublist in _inlets if sublist for item in sublist
- ]
+ # xcom_pull returns a list of items for each given task_id
+ _inlets = [item for item in itertools.chain.from_iterable(_inlets)]
self.inlets.extend(_inlets)
- self.inlets.extend(self._inlets)
- elif self._inlets:
+ elif self.inlets:
raise AttributeError("inlets is not a list, operator, string or attr annotated object")
- if not isinstance(self._outlets, list):
- self._outlets = [
- self._outlets,
- ]
-
- self.outlets.extend(self._outlets)
+ if not isinstance(self.outlets, list):
+ self.outlets = [self.outlets]
# render inlets and outlets
- self.inlets = [_render_object(i, context) for i in self.inlets if attr.has(i)]
+ self.inlets = [_render_object(i, context) for i in self.inlets]
- self.outlets = [_render_object(i, context) for i in self.outlets if attr.has(i)]
+ self.outlets = [_render_object(i, context) for i in self.outlets]
self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
+
return func(self, context, *args, **kwargs)
return cast(T, wrapper)
diff --git a/airflow/lineage/backend.py b/airflow/lineage/backend.py
index ca072f434105a..29a755109c64f 100644
--- a/airflow/lineage/backend.py
+++ b/airflow/lineage/backend.py
@@ -15,25 +15,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Sends lineage metadata to a backend"""
-from typing import TYPE_CHECKING, Optional
+"""Sends lineage metadata to a backend."""
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator
class LineageBackend:
- """Sends lineage metadata to a backend"""
+ """Sends lineage metadata to a backend."""
def send_lineage(
self,
- operator: 'BaseOperator',
- inlets: Optional[list] = None,
- outlets: Optional[list] = None,
- context: Optional[dict] = None,
+ operator: BaseOperator,
+ inlets: list | None = None,
+ outlets: list | None = None,
+ context: dict | None = None,
):
"""
- Sends lineage metadata to a backend
+ Sends lineage metadata to a backend.
:param operator: the operator executing a transformation on the inlets and outlets
:param inlets: the inlets to this operator
diff --git a/airflow/lineage/entities.py b/airflow/lineage/entities.py
index 87703edfbaf10..c52dc4cfd0133 100644
--- a/airflow/lineage/entities.py
+++ b/airflow/lineage/entities.py
@@ -15,31 +15,33 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""Defines base entities used for providing lineage information."""
+from __future__ import annotations
-"""
-Defines the base entities that can be used for providing lineage
-information.
-"""
-from typing import Any, Dict, List, Optional
+from typing import Any, ClassVar
import attr
@attr.s(auto_attribs=True)
class File:
- """File entity. Refers to a file"""
+ """File entity. Refers to a file."""
+
+ template_fields: ClassVar = ("url",)
url: str = attr.ib()
- type_hint: Optional[str] = None
+ type_hint: str | None = None
@attr.s(auto_attribs=True, kw_only=True)
class User:
- """User entity. Identifies a user"""
+ """User entity. Identifies a user."""
email: str = attr.ib()
- first_name: Optional[str] = None
- last_name: Optional[str] = None
+ first_name: str | None = None
+ last_name: str | None = None
+
+ template_fields: ClassVar = ("email", "first_name", "last_name")
@attr.s(auto_attribs=True, kw_only=True)
@@ -48,15 +50,19 @@ class Tag:
tag_name: str = attr.ib()
+ template_fields: ClassVar = ("tag_name",)
+
@attr.s(auto_attribs=True, kw_only=True)
class Column:
- """Column of a Table"""
+ """Column of a Table."""
name: str = attr.ib()
- description: Optional[str] = None
+ description: str | None = None
data_type: str = attr.ib()
- tags: List[Tag] = []
+ tags: list[Tag] = []
+
+ template_fields: ClassVar = ("name", "description", "data_type", "tags")
# this is a temporary hack to satisfy mypy. Once
@@ -64,20 +70,32 @@ class Column:
# `attr.converters.default_if_none(default=False)`
-def default_if_none(arg: Optional[bool]) -> bool:
+def default_if_none(arg: bool | None) -> bool:
+ """Get default value when None."""
return arg or False
@attr.s(auto_attribs=True, kw_only=True)
class Table:
- """Table entity"""
+ """Table entity."""
database: str = attr.ib()
cluster: str = attr.ib()
name: str = attr.ib()
- tags: List[Tag] = []
- description: Optional[str] = None
- columns: List[Column] = []
- owners: List[User] = []
- extra: Dict[str, Any] = {}
- type_hint: Optional[str] = None
+ tags: list[Tag] = []
+ description: str | None = None
+ columns: list[Column] = []
+ owners: list[User] = []
+ extra: dict[str, Any] = {}
+ type_hint: str | None = None
+
+ template_fields: ClassVar = (
+ "database",
+ "cluster",
+ "name",
+ "tags",
+ "description",
+ "columns",
+ "owners",
+ "extra",
+ )
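Review note: the lineage entities now declare `template_fields`, and `_render_object` delegates to the task's `render_template`. The sketch below shows the general idea of rendering only the declared fields. It is a toy: `FileSketch` and `render_sketch` are hypothetical, the entity is a plain dataclass rather than an `attr` class, and simple string replacement stands in for Jinja rendering.

```python
from dataclasses import dataclass, replace
from typing import Any, ClassVar, Optional


@dataclass
class FileSketch:
    """Hypothetical stand-in for the File entity."""

    # Names of the attributes that may contain template expressions.
    template_fields: ClassVar = ("url",)

    url: str
    type_hint: Optional[str] = None


def render_sketch(obj: Any, context: dict) -> Any:
    """Return a copy of obj with its declared string fields rendered (toy renderer)."""
    rendered = {
        name: getattr(obj, name).replace("{{ ds }}", context["ds"])
        for name in obj.template_fields
        if isinstance(getattr(obj, name), str)
    }
    return replace(obj, **rendered)


source = FileSketch(url="s3://bucket/{{ ds }}/data.csv")
rendered_file = render_sketch(source, {"ds": "2022-09-14"})
```

Declaring the renderable attributes on the entity keeps rendering generic: the renderer needs no knowledge of each entity type, only of `template_fields`.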
diff --git a/airflow/listeners/__init__.py b/airflow/listeners/__init__.py
index d1df70bda1158..87840b50e2fa5 100644
--- a/airflow/listeners/__init__.py
+++ b/airflow/listeners/__init__.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from pluggy import HookimplMarker
hookimpl = HookimplMarker("airflow")
diff --git a/airflow/listeners/events.py b/airflow/listeners/events.py
index d5af64710a17a..53c113af8ee66 100644
--- a/airflow/listeners/events.py
+++ b/airflow/listeners/events.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
from sqlalchemy import event
@@ -28,6 +30,8 @@
def on_task_instance_state_session_flush(session, flush_context):
"""
+ Flush task instance's state.
+
Listens for session.flush() events that modify TaskInstance's state, and notify listeners that listen
for that event. Doing it this way enable us to be stateless in the SQLAlchemy event listener.
"""
@@ -38,7 +42,7 @@ def on_task_instance_state_session_flush(session, flush_context):
if isinstance(state.object, TaskInstance) and session.is_modified(
state.object, include_collections=False
):
- added, unchanged, deleted = flush_context.get_attribute_history(state, 'state')
+ added, unchanged, deleted = flush_context.get_attribute_history(state, "state")
logger.debug(
"session flush listener: added %s unchanged %s deleted %s - %s",
@@ -67,13 +71,15 @@ def on_task_instance_state_session_flush(session, flush_context):
def register_task_instance_state_events():
+ """Register a task instance state event."""
global _is_listening
if not _is_listening:
- event.listen(Session, 'after_flush', on_task_instance_state_session_flush)
+ event.listen(Session, "after_flush", on_task_instance_state_session_flush)
_is_listening = True
def unregister_task_instance_state_events():
+ """Unregister a task instance state event."""
global _is_listening
- event.remove(Session, 'after_flush', on_task_instance_state_session_flush)
+ event.remove(Session, "after_flush", on_task_instance_state_session_flush)
_is_listening = False
diff --git a/airflow/listeners/listener.py b/airflow/listeners/listener.py
index 3c4d052399dab..546d732513f76 100644
--- a/airflow/listeners/listener.py
+++ b/airflow/listeners/listener.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
-from types import ModuleType
from typing import TYPE_CHECKING
import pluggy
@@ -33,37 +34,38 @@
class ListenerManager:
- """Class that manages registration of listeners and provides hook property for calling them"""
+ """Manage listener registration and provide a hook property for calling them."""
def __init__(self):
- from airflow.listeners import spec
+ from airflow.listeners.spec import dagrun, lifecycle, taskinstance
self.pm = pluggy.PluginManager("airflow")
- self.pm.add_hookspecs(spec)
+ self.pm.add_hookspecs(lifecycle)
+ self.pm.add_hookspecs(dagrun)
+ self.pm.add_hookspecs(taskinstance)
@property
def has_listeners(self) -> bool:
return len(self.pm.get_plugins()) > 0
@property
- def hook(self) -> "_HookRelay":
- """Returns hook, on which plugin methods specified in spec can be called."""
+ def hook(self) -> _HookRelay:
+ """Return hook, on which plugin methods specified in spec can be called."""
return self.pm.hook
def add_listener(self, listener):
- if not isinstance(listener, ModuleType):
- raise TypeError("Listener %s is not module", str(listener))
if self.pm.is_registered(listener):
return
self.pm.register(listener)
def clear(self):
- """Remove registered plugins"""
+ """Remove registered plugins."""
for plugin in self.pm.get_plugins():
self.pm.unregister(plugin)
def get_listener_manager() -> ListenerManager:
+ """Get singleton listener manager."""
global _listener_manager
if not _listener_manager:
_listener_manager = ListenerManager()
diff --git a/airflow/listeners/spec.py b/airflow/listeners/spec.py
deleted file mode 100644
index fbaf63e89ac6e..0000000000000
--- a/airflow/listeners/spec.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from typing import TYPE_CHECKING, Optional
-
-from pluggy import HookspecMarker
-
-if TYPE_CHECKING:
- from sqlalchemy.orm.session import Session
-
- from airflow.models.taskinstance import TaskInstance
- from airflow.utils.state import TaskInstanceState
-
-hookspec = HookspecMarker("airflow")
-
-
-@hookspec
-def on_task_instance_running(
- previous_state: "TaskInstanceState", task_instance: "TaskInstance", session: Optional["Session"]
-):
- """Called when task state changes to RUNNING. Previous_state can be State.NONE."""
-
-
-@hookspec
-def on_task_instance_success(
- previous_state: "TaskInstanceState", task_instance: "TaskInstance", session: Optional["Session"]
-):
- """Called when task state changes to SUCCESS. Previous_state can be State.NONE."""
-
-
-@hookspec
-def on_task_instance_failed(
- previous_state: "TaskInstanceState", task_instance: "TaskInstance", session: Optional["Session"]
-):
- """Called when task state changes to FAIL. Previous_state can be State.NONE."""
diff --git a/airflow/providers/airbyte/example_dags/__init__.py b/airflow/listeners/spec/__init__.py
similarity index 100%
rename from airflow/providers/airbyte/example_dags/__init__.py
rename to airflow/listeners/spec/__init__.py
diff --git a/airflow/listeners/spec/dagrun.py b/airflow/listeners/spec/dagrun.py
new file mode 100644
index 0000000000000..d2ae1a6b78cb5
--- /dev/null
+++ b/airflow/listeners/spec/dagrun.py
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pluggy import HookspecMarker
+
+if TYPE_CHECKING:
+ from airflow.models.dagrun import DagRun
+
+hookspec = HookspecMarker("airflow")
+
+
+@hookspec
+def on_dag_run_running(dag_run: DagRun, msg: str):
+ """Called when dag run state changes to RUNNING."""
+
+
+@hookspec
+def on_dag_run_success(dag_run: DagRun, msg: str):
+ """Called when dag run state changes to SUCCESS."""
+
+
+@hookspec
+def on_dag_run_failed(dag_run: DagRun, msg: str):
+ """Called when dag run state changes to FAIL."""
diff --git a/airflow/listeners/spec/lifecycle.py b/airflow/listeners/spec/lifecycle.py
new file mode 100644
index 0000000000000..6ab0aa3b5cde1
--- /dev/null
+++ b/airflow/listeners/spec/lifecycle.py
@@ -0,0 +1,44 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from pluggy import HookspecMarker
+
+hookspec = HookspecMarker("airflow")
+
+
+@hookspec
+def on_starting(component):
+ """
+ Called before an Airflow component (a job such as the scheduler, a worker, or a task runner) starts.
+
+ It's guaranteed this will be called before any other plugin method.
+
+ :param component: Component that calls this method
+ """
+
+
+@hookspec
+def before_stopping(component):
+ """
+ Called before an Airflow component (a job such as the scheduler, a worker, or a task runner) stops.
+
+ It's guaranteed this will be called after any other plugin method.
+
+ :param component: Component that calls this method
+ """
diff --git a/airflow/listeners/spec/taskinstance.py b/airflow/listeners/spec/taskinstance.py
new file mode 100644
index 0000000000000..78de8a5f62b14
--- /dev/null
+++ b/airflow/listeners/spec/taskinstance.py
@@ -0,0 +1,51 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from pluggy import HookspecMarker
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
+
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.utils.state import TaskInstanceState
+
+hookspec = HookspecMarker("airflow")
+
+
+@hookspec
+def on_task_instance_running(
+ previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None
+):
+ """Called when task state changes to RUNNING. Previous_state can be State.NONE."""
+
+
+@hookspec
+def on_task_instance_success(
+ previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None
+):
+ """Called when task state changes to SUCCESS. Previous_state can be State.NONE."""
+
+
+@hookspec
+def on_task_instance_failed(
+ previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session | None
+):
+ """Called when task state changes to FAIL. Previous_state can be State.NONE."""
diff --git a/airflow/logging_config.py b/airflow/logging_config.py
index 645e53eb3efaa..d78f84cb6a443 100644
--- a/airflow/logging_config.py
+++ b/airflow/logging_config.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import logging
import warnings
from logging.config import dictConfig
@@ -29,11 +30,11 @@
def configure_logging():
"""Configure & Validate Airflow Logging"""
- logging_class_path = ''
+ logging_class_path = ""
try:
- logging_class_path = conf.get('logging', 'logging_config_class')
+ logging_class_path = conf.get("logging", "logging_config_class")
except AirflowConfigException:
- log.debug('Could not find key logging_config_class in config')
+ log.debug("Could not find key logging_config_class in config")
if logging_class_path:
try:
@@ -43,31 +44,31 @@ def configure_logging():
if not isinstance(logging_config, dict):
raise ValueError("Logging Config should be of dict type")
- log.info('Successfully imported user-defined logging config from %s', logging_class_path)
+ log.info("Successfully imported user-defined logging config from %s", logging_class_path)
except Exception as err:
# Import default logging configurations.
- raise ImportError(f'Unable to load custom logging from {logging_class_path} due to {err}')
+ raise ImportError(f"Unable to load custom logging from {logging_class_path} due to {err}")
else:
- logging_class_path = 'airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG'
+ logging_class_path = "airflow.config_templates.airflow_local_settings.DEFAULT_LOGGING_CONFIG"
logging_config = import_string(logging_class_path)
- log.debug('Unable to load custom logging, using default config instead')
+ log.debug("Unable to load custom logging, using default config instead")
try:
# Ensure that the password masking filter is applied to the 'task' handler
# no matter what the user did.
- if 'filters' in logging_config and 'mask_secrets' in logging_config['filters']:
+ if "filters" in logging_config and "mask_secrets" in logging_config["filters"]:
# But if they replace the logging config _entirely_, don't try to set this, it won't work
- task_handler_config = logging_config['handlers']['task']
+ task_handler_config = logging_config["handlers"]["task"]
- task_handler_config.setdefault('filters', [])
+ task_handler_config.setdefault("filters", [])
- if 'mask_secrets' not in task_handler_config['filters']:
- task_handler_config['filters'].append('mask_secrets')
+ if "mask_secrets" not in task_handler_config["filters"]:
+ task_handler_config["filters"].append("mask_secrets")
# Try to init logging
dictConfig(logging_config)
except (ValueError, KeyError) as e:
- log.error('Unable to load the config, contains a configuration error.')
+ log.error("Unable to load the config, contains a configuration error.")
# When there is an error in the config, escalate the exception
# otherwise Airflow would silently fall back on the default config
raise e
@@ -80,9 +81,9 @@ def configure_logging():
def validate_logging_config(logging_config):
"""Validate the provided Logging Config"""
# Now lets validate the other logging-related settings
- task_log_reader = conf.get('logging', 'task_log_reader')
+ task_log_reader = conf.get("logging", "task_log_reader")
- logger = logging.getLogger('airflow.task')
+ logger = logging.getLogger("airflow.task")
def _get_handler(name):
return next((h for h in logger.handlers if h.name == name), None)
@@ -96,7 +97,7 @@ def _get_handler(name):
"Running config has been adjusted to match",
DeprecationWarning,
)
- conf.set('logging', 'task_log_reader', 'task')
+ conf.set("logging", "task_log_reader", "task")
else:
raise AirflowConfigException(
f"Configured task_log_reader {task_log_reader!r} was not a handler of "
diff --git a/airflow/macros/__init__.py b/airflow/macros/__init__.py
index e1f27411e5069..4364d3278a1e1 100644
--- a/airflow/macros/__init__.py
+++ b/airflow/macros/__init__.py
@@ -15,11 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import time # noqa
import uuid # noqa
from datetime import datetime, timedelta
from random import random # noqa
-from typing import Any, Optional
+from typing import Any
import dateutil # noqa
from pendulum import DateTime
@@ -29,7 +31,7 @@
def ds_add(ds: str, days: int) -> str:
"""
- Add or subtract days from a YYYY-MM-DD
+ Add or subtract days from a YYYY-MM-DD.
:param ds: anchor date in ``YYYY-MM-DD`` format to add to
:param days: number of days to add to the ds, you can use negative values
@@ -47,8 +49,7 @@ def ds_add(ds: str, days: int) -> str:
def ds_format(ds: str, input_format: str, output_format: str) -> str:
"""
- Takes an input string and outputs another string
- as specified in the output format
+ Output datetime string in a given format.
:param ds: input string which contains a date
:param input_format: input string format. E.g. %Y-%m-%d
@@ -62,15 +63,15 @@ def ds_format(ds: str, input_format: str, output_format: str) -> str:
return datetime.strptime(str(ds), input_format).strftime(output_format)
-def datetime_diff_for_humans(dt: Any, since: Optional[DateTime] = None) -> str:
+def datetime_diff_for_humans(dt: Any, since: DateTime | None = None) -> str:
"""
- Return a human-readable/approximate difference between two datetimes, or
- one and now.
+ Return a human-readable/approximate difference between datetimes.
+
+ When only one datetime is provided, the comparison will be based on now.
:param dt: The datetime to display the diff for
:param since: When to display the date from. If ``None`` then the diff is
between ``dt`` and now.
- :rtype: str
"""
import pendulum
diff --git a/airflow/macros/hive.py b/airflow/macros/hive.py
index fe5685fcf2536..7da0202eb2596 100644
--- a/airflow/macros/hive.py
+++ b/airflow/macros/hive.py
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import datetime
def max_partition(
- table, schema="default", field=None, filter_map=None, metastore_conn_id='metastore_default'
+ table, schema="default", field=None, filter_map=None, metastore_conn_id="metastore_default"
):
"""
Gets the max partition for a table.
@@ -43,13 +44,13 @@ def max_partition(
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
- if '.' in table:
- schema, table = table.split('.')
+ if "." in table:
+ schema, table = table.split(".")
hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
return hive_hook.max_partition(schema=schema, table_name=table, field=field, filter_map=filter_map)
-def _closest_date(target_dt, date_list, before_target=None):
+def _closest_date(target_dt, date_list, before_target=None) -> datetime.date | None:
"""
This function finds the date in a list closest to the target date.
An optional parameter can be given to get the closest before or after.
@@ -58,7 +59,6 @@ def _closest_date(target_dt, date_list, before_target=None):
:param date_list: The list of dates to search
:param before_target: closest before or after the target
:returns: The closest date
- :rtype: datetime.date or None
"""
time_before = lambda d: target_dt - d if d <= target_dt else datetime.timedelta.max
time_after = lambda d: d - target_dt if d >= target_dt else datetime.timedelta.max
@@ -71,7 +71,9 @@ def _closest_date(target_dt, date_list, before_target=None):
return min(date_list, key=time_after).date()
-def closest_ds_partition(table, ds, before=True, schema="default", metastore_conn_id='metastore_default'):
+def closest_ds_partition(
+ table, ds, before=True, schema="default", metastore_conn_id="metastore_default"
+) -> str | None:
"""
This function finds the date in a list closest to the target date.
An optional parameter can be given to get the closest before or after.
@@ -82,7 +84,6 @@ def closest_ds_partition(table, ds, before=True, schema="default", metastore_con
:param schema: table schema
:param metastore_conn_id: which metastore connection to use
:returns: The closest date
- :rtype: str or None
>>> tbl = 'airflow.static_babynames_partitioned'
>>> closest_ds_partition(tbl, '2015-01-02')
@@ -90,8 +91,8 @@ def closest_ds_partition(table, ds, before=True, schema="default", metastore_con
"""
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
- if '.' in table:
- schema, table = table.split('.')
+ if "." in table:
+ schema, table = table.split(".")
hive_hook = HiveMetastoreHook(metastore_conn_id=metastore_conn_id)
partitions = hive_hook.get_partitions(schema=schema, table_name=table)
if not partitions:
@@ -100,7 +101,9 @@ def closest_ds_partition(table, ds, before=True, schema="default", metastore_con
if ds in part_vals:
return ds
else:
- parts = [datetime.datetime.strptime(pv, '%Y-%m-%d') for pv in part_vals]
- target_dt = datetime.datetime.strptime(ds, '%Y-%m-%d')
+ parts = [datetime.datetime.strptime(pv, "%Y-%m-%d") for pv in part_vals]
+ target_dt = datetime.datetime.strptime(ds, "%Y-%m-%d")
closest_ds = _closest_date(target_dt, parts, before_target=before)
- return closest_ds.isoformat()
+ if closest_ds is not None:
+ return closest_ds.isoformat()
+ return None
diff --git a/airflow/migrations/db_types.py b/airflow/migrations/db_types.py
index 9b8f3e974ffd9..70a1a1e6b47f4 100644
--- a/airflow/migrations/db_types.py
+++ b/airflow/migrations/db_types.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import context
from lazy_object_proxy import Proxy
@@ -71,7 +72,7 @@ def lazy_load():
module = globals()
# Lookup the type based on the dialect specific type, or fallback to the generic type
- type_ = module.get(f'_{dialect}_{name}', None) or module.get(f'_sa_{name}')
+ type_ = module.get(f"_{dialect}_{name}", None) or module.get(f"_sa_{name}")
val = module[name] = type_()
return val
diff --git a/airflow/migrations/env.py b/airflow/migrations/env.py
index 58a8f7f4a0465..9f97195a6d9c4 100644
--- a/airflow/migrations/env.py
+++ b/airflow/migrations/env.py
@@ -15,12 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+import contextlib
+import sys
from logging.config import fileConfig
from alembic import context
from airflow import models, settings
+from airflow.utils.db import compare_server_default, compare_type
def include_object(_, name, type_, *args):
@@ -32,6 +36,9 @@ def include_object(_, name, type_, *args):
return True
+# Make sure everything is imported so that alembic can find it all
+models.import_all_models()
+
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
@@ -51,8 +58,6 @@ def include_object(_, name, type_, *args):
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
-COMPARE_TYPE = False
-
def run_migrations_offline():
"""Run migrations in 'offline' mode.
@@ -70,7 +75,8 @@ def run_migrations_offline():
url=settings.SQL_ALCHEMY_CONN,
target_metadata=target_metadata,
literal_binds=True,
- compare_type=COMPARE_TYPE,
+ compare_type=compare_type,
+ compare_server_default=compare_server_default,
render_as_batch=True,
)
@@ -85,14 +91,18 @@ def run_migrations_online():
and associate a connection with the context.
"""
- connectable = settings.engine
+ with contextlib.ExitStack() as stack:
+ connection = config.attributes.get("connection", None)
+
+ if not connection:
+ connection = stack.push(settings.engine.connect())
- with connectable.connect() as connection:
context.configure(
connection=connection,
transaction_per_migration=True,
target_metadata=target_metadata,
- compare_type=COMPARE_TYPE,
+ compare_type=compare_type,
+ compare_server_default=compare_server_default,
include_object=include_object,
render_as_batch=True,
)
@@ -105,3 +115,9 @@ def run_migrations_online():
run_migrations_offline()
else:
run_migrations_online()
+
+if "airflow.www.app" in sys.modules:
+ # Already imported, make sure we clear out any cached app
+ from airflow.www.app import purge_cached_app
+
+ purge_cached_app()
diff --git a/airflow/migrations/utils.py b/airflow/migrations/utils.py
index 5737fa950774d..78f925ac84b53 100644
--- a/airflow/migrations/utils.py
+++ b/airflow/migrations/utils.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from collections import defaultdict
from contextlib import contextmanager
-def get_mssql_table_constraints(conn, table_name):
+def get_mssql_table_constraints(conn, table_name) -> dict[str, dict[str, list[str]]]:
"""
This function return primary and unique constraint
along with column name. Some tables like `task_instance`
@@ -29,7 +30,6 @@ def get_mssql_table_constraints(conn, table_name):
:param conn: sql connection object
:param table_name: table name
:return: a dictionary of ((constraint name, constraint type), column name) of table
- :rtype: defaultdict(list)
"""
query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
@@ -47,7 +47,7 @@ def get_mssql_table_constraints(conn, table_name):
@contextmanager
def disable_sqlite_fkeys(op):
- if op.get_bind().dialect.name == 'sqlite':
+ if op.get_bind().dialect.name == "sqlite":
op.execute("PRAGMA foreign_keys=off")
yield op
op.execute("PRAGMA foreign_keys=on")
diff --git a/airflow/migrations/versions/0001_1_5_0_current_schema.py b/airflow/migrations/versions/0001_1_5_0_current_schema.py
index 9824db7dad36f..0bfc7ca518308 100644
--- a/airflow/migrations/versions/0001_1_5_0_current_schema.py
+++ b/airflow/migrations/versions/0001_1_5_0_current_schema.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""current schema
Revision ID: e3a246e0dc1
@@ -23,20 +22,20 @@
Create Date: 2015-08-18 16:35:00.883495
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
-from sqlalchemy import func
+from sqlalchemy import func, inspect
-from airflow.compat.sqlalchemy import inspect
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = 'e3a246e0dc1'
+revision = "e3a246e0dc1"
down_revision = None
branch_labels = None
depends_on = None
-airflow_version = '1.5.0'
+airflow_version = "1.5.0"
def upgrade():
@@ -44,199 +43,199 @@ def upgrade():
inspector = inspect(conn)
tables = inspector.get_table_names()
- if 'connection' not in tables:
+ if "connection" not in tables:
op.create_table(
- 'connection',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('conn_id', StringID(), nullable=True),
- sa.Column('conn_type', sa.String(length=500), nullable=True),
- sa.Column('host', sa.String(length=500), nullable=True),
- sa.Column('schema', sa.String(length=500), nullable=True),
- sa.Column('login', sa.String(length=500), nullable=True),
- sa.Column('password', sa.String(length=500), nullable=True),
- sa.Column('port', sa.Integer(), nullable=True),
- sa.Column('extra', sa.String(length=5000), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ "connection",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("conn_id", StringID(), nullable=True),
+ sa.Column("conn_type", sa.String(length=500), nullable=True),
+ sa.Column("host", sa.String(length=500), nullable=True),
+ sa.Column("schema", sa.String(length=500), nullable=True),
+ sa.Column("login", sa.String(length=500), nullable=True),
+ sa.Column("password", sa.String(length=500), nullable=True),
+ sa.Column("port", sa.Integer(), nullable=True),
+ sa.Column("extra", sa.String(length=5000), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
- if 'dag' not in tables:
+ if "dag" not in tables:
op.create_table(
- 'dag',
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('is_paused', sa.Boolean(), nullable=True),
- sa.Column('is_subdag', sa.Boolean(), nullable=True),
- sa.Column('is_active', sa.Boolean(), nullable=True),
- sa.Column('last_scheduler_run', sa.DateTime(), nullable=True),
- sa.Column('last_pickled', sa.DateTime(), nullable=True),
- sa.Column('last_expired', sa.DateTime(), nullable=True),
- sa.Column('scheduler_lock', sa.Boolean(), nullable=True),
- sa.Column('pickle_id', sa.Integer(), nullable=True),
- sa.Column('fileloc', sa.String(length=2000), nullable=True),
- sa.Column('owners', sa.String(length=2000), nullable=True),
- sa.PrimaryKeyConstraint('dag_id'),
+ "dag",
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("is_paused", sa.Boolean(), nullable=True),
+ sa.Column("is_subdag", sa.Boolean(), nullable=True),
+ sa.Column("is_active", sa.Boolean(), nullable=True),
+ sa.Column("last_scheduler_run", sa.DateTime(), nullable=True),
+ sa.Column("last_pickled", sa.DateTime(), nullable=True),
+ sa.Column("last_expired", sa.DateTime(), nullable=True),
+ sa.Column("scheduler_lock", sa.Boolean(), nullable=True),
+ sa.Column("pickle_id", sa.Integer(), nullable=True),
+ sa.Column("fileloc", sa.String(length=2000), nullable=True),
+ sa.Column("owners", sa.String(length=2000), nullable=True),
+ sa.PrimaryKeyConstraint("dag_id"),
)
- if 'dag_pickle' not in tables:
+ if "dag_pickle" not in tables:
op.create_table(
- 'dag_pickle',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('pickle', sa.PickleType(), nullable=True),
- sa.Column('created_dttm', sa.DateTime(), nullable=True),
- sa.Column('pickle_hash', sa.BigInteger(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ "dag_pickle",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("pickle", sa.PickleType(), nullable=True),
+ sa.Column("created_dttm", sa.DateTime(), nullable=True),
+ sa.Column("pickle_hash", sa.BigInteger(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
- if 'import_error' not in tables:
+ if "import_error" not in tables:
op.create_table(
- 'import_error',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('timestamp', sa.DateTime(), nullable=True),
- sa.Column('filename', sa.String(length=1024), nullable=True),
- sa.Column('stacktrace', sa.Text(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ "import_error",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("timestamp", sa.DateTime(), nullable=True),
+ sa.Column("filename", sa.String(length=1024), nullable=True),
+ sa.Column("stacktrace", sa.Text(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
- if 'job' not in tables:
+ if "job" not in tables:
op.create_table(
- 'job',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('dag_id', sa.String(length=250), nullable=True),
- sa.Column('state', sa.String(length=20), nullable=True),
- sa.Column('job_type', sa.String(length=30), nullable=True),
- sa.Column('start_date', sa.DateTime(), nullable=True),
- sa.Column('end_date', sa.DateTime(), nullable=True),
- sa.Column('latest_heartbeat', sa.DateTime(), nullable=True),
- sa.Column('executor_class', sa.String(length=500), nullable=True),
- sa.Column('hostname', sa.String(length=500), nullable=True),
- sa.Column('unixname', sa.String(length=1000), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ "job",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("dag_id", sa.String(length=250), nullable=True),
+ sa.Column("state", sa.String(length=20), nullable=True),
+ sa.Column("job_type", sa.String(length=30), nullable=True),
+ sa.Column("start_date", sa.DateTime(), nullable=True),
+ sa.Column("end_date", sa.DateTime(), nullable=True),
+ sa.Column("latest_heartbeat", sa.DateTime(), nullable=True),
+ sa.Column("executor_class", sa.String(length=500), nullable=True),
+ sa.Column("hostname", sa.String(length=500), nullable=True),
+ sa.Column("unixname", sa.String(length=1000), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
- op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False)
- if 'log' not in tables:
+ op.create_index("job_type_heart", "job", ["job_type", "latest_heartbeat"], unique=False)
+ if "log" not in tables:
op.create_table(
- 'log',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('dttm', sa.DateTime(), nullable=True),
- sa.Column('dag_id', StringID(), nullable=True),
- sa.Column('task_id', StringID(), nullable=True),
- sa.Column('event', sa.String(length=30), nullable=True),
- sa.Column('execution_date', sa.DateTime(), nullable=True),
- sa.Column('owner', sa.String(length=500), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ "log",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("dttm", sa.DateTime(), nullable=True),
+ sa.Column("dag_id", StringID(), nullable=True),
+ sa.Column("task_id", StringID(), nullable=True),
+ sa.Column("event", sa.String(length=30), nullable=True),
+ sa.Column("execution_date", sa.DateTime(), nullable=True),
+ sa.Column("owner", sa.String(length=500), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
- if 'sla_miss' not in tables:
+ if "sla_miss" not in tables:
op.create_table(
- 'sla_miss',
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('execution_date', sa.DateTime(), nullable=False),
- sa.Column('email_sent', sa.Boolean(), nullable=True),
- sa.Column('timestamp', sa.DateTime(), nullable=True),
- sa.Column('description', sa.Text(), nullable=True),
- sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'),
+ "sla_miss",
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("execution_date", sa.DateTime(), nullable=False),
+ sa.Column("email_sent", sa.Boolean(), nullable=True),
+ sa.Column("timestamp", sa.DateTime(), nullable=True),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.PrimaryKeyConstraint("task_id", "dag_id", "execution_date"),
)
- if 'slot_pool' not in tables:
+ if "slot_pool" not in tables:
op.create_table(
- 'slot_pool',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('pool', StringID(length=50), nullable=True),
- sa.Column('slots', sa.Integer(), nullable=True),
- sa.Column('description', sa.Text(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('pool'),
+ "slot_pool",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("pool", StringID(length=50), nullable=True),
+ sa.Column("slots", sa.Integer(), nullable=True),
+ sa.Column("description", sa.Text(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("pool"),
)
- if 'task_instance' not in tables:
+ if "task_instance" not in tables:
op.create_table(
- 'task_instance',
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('execution_date', sa.DateTime(), nullable=False),
- sa.Column('start_date', sa.DateTime(), nullable=True),
- sa.Column('end_date', sa.DateTime(), nullable=True),
- sa.Column('duration', sa.Integer(), nullable=True),
- sa.Column('state', sa.String(length=20), nullable=True),
- sa.Column('try_number', sa.Integer(), nullable=True),
- sa.Column('hostname', sa.String(length=1000), nullable=True),
- sa.Column('unixname', sa.String(length=1000), nullable=True),
- sa.Column('job_id', sa.Integer(), nullable=True),
- sa.Column('pool', sa.String(length=50), nullable=True),
- sa.Column('queue', sa.String(length=50), nullable=True),
- sa.Column('priority_weight', sa.Integer(), nullable=True),
- sa.PrimaryKeyConstraint('task_id', 'dag_id', 'execution_date'),
+ "task_instance",
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("execution_date", sa.DateTime(), nullable=False),
+ sa.Column("start_date", sa.DateTime(), nullable=True),
+ sa.Column("end_date", sa.DateTime(), nullable=True),
+ sa.Column("duration", sa.Integer(), nullable=True),
+ sa.Column("state", sa.String(length=20), nullable=True),
+ sa.Column("try_number", sa.Integer(), nullable=True),
+ sa.Column("hostname", sa.String(length=1000), nullable=True),
+ sa.Column("unixname", sa.String(length=1000), nullable=True),
+ sa.Column("job_id", sa.Integer(), nullable=True),
+ sa.Column("pool", sa.String(length=50), nullable=True),
+ sa.Column("queue", sa.String(length=50), nullable=True),
+ sa.Column("priority_weight", sa.Integer(), nullable=True),
+ sa.PrimaryKeyConstraint("task_id", "dag_id", "execution_date"),
)
- op.create_index('ti_dag_state', 'task_instance', ['dag_id', 'state'], unique=False)
- op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'], unique=False)
+ op.create_index("ti_dag_state", "task_instance", ["dag_id", "state"], unique=False)
+ op.create_index("ti_pool", "task_instance", ["pool", "state", "priority_weight"], unique=False)
op.create_index(
- 'ti_state_lkp', 'task_instance', ['dag_id', 'task_id', 'execution_date', 'state'], unique=False
+ "ti_state_lkp", "task_instance", ["dag_id", "task_id", "execution_date", "state"], unique=False
)
- if 'user' not in tables:
+ if "user" not in tables:
op.create_table(
- 'user',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('username', StringID(), nullable=True),
- sa.Column('email', sa.String(length=500), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('username'),
+ "user",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("username", StringID(), nullable=True),
+ sa.Column("email", sa.String(length=500), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("username"),
)
- if 'variable' not in tables:
+ if "variable" not in tables:
op.create_table(
- 'variable',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('key', StringID(), nullable=True),
- sa.Column('val', sa.Text(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('key'),
+ "variable",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("key", StringID(), nullable=True),
+ sa.Column("val", sa.Text(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("key"),
)
- if 'chart' not in tables:
+ if "chart" not in tables:
op.create_table(
- 'chart',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('label', sa.String(length=200), nullable=True),
- sa.Column('conn_id', sa.String(length=250), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('chart_type', sa.String(length=100), nullable=True),
- sa.Column('sql_layout', sa.String(length=50), nullable=True),
- sa.Column('sql', sa.Text(), nullable=True),
- sa.Column('y_log_scale', sa.Boolean(), nullable=True),
- sa.Column('show_datatable', sa.Boolean(), nullable=True),
- sa.Column('show_sql', sa.Boolean(), nullable=True),
- sa.Column('height', sa.Integer(), nullable=True),
- sa.Column('default_params', sa.String(length=5000), nullable=True),
- sa.Column('x_is_date', sa.Boolean(), nullable=True),
- sa.Column('iteration_no', sa.Integer(), nullable=True),
- sa.Column('last_modified', sa.DateTime(), nullable=True),
+ "chart",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("label", sa.String(length=200), nullable=True),
+ sa.Column("conn_id", sa.String(length=250), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("chart_type", sa.String(length=100), nullable=True),
+ sa.Column("sql_layout", sa.String(length=50), nullable=True),
+ sa.Column("sql", sa.Text(), nullable=True),
+ sa.Column("y_log_scale", sa.Boolean(), nullable=True),
+ sa.Column("show_datatable", sa.Boolean(), nullable=True),
+ sa.Column("show_sql", sa.Boolean(), nullable=True),
+ sa.Column("height", sa.Integer(), nullable=True),
+ sa.Column("default_params", sa.String(length=5000), nullable=True),
+ sa.Column("x_is_date", sa.Boolean(), nullable=True),
+ sa.Column("iteration_no", sa.Integer(), nullable=True),
+ sa.Column("last_modified", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
- ['user_id'],
- ['user.id'],
+ ["user_id"],
+ ["user.id"],
),
- sa.PrimaryKeyConstraint('id'),
+ sa.PrimaryKeyConstraint("id"),
)
- if 'xcom' not in tables:
+ if "xcom" not in tables:
op.create_table(
- 'xcom',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('key', StringID(length=512), nullable=True),
- sa.Column('value', sa.PickleType(), nullable=True),
- sa.Column('timestamp', sa.DateTime(), default=func.now(), nullable=False),
- sa.Column('execution_date', sa.DateTime(), nullable=False),
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
- sa.PrimaryKeyConstraint('id'),
+ "xcom",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("key", StringID(length=512), nullable=True),
+ sa.Column("value", sa.PickleType(), nullable=True),
+ sa.Column("timestamp", sa.DateTime(), default=func.now, nullable=False),
+ sa.Column("execution_date", sa.DateTime(), nullable=False),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
)
def downgrade():
- op.drop_table('chart')
- op.drop_table('variable')
- op.drop_table('user')
- op.drop_index('ti_state_lkp', table_name='task_instance')
- op.drop_index('ti_pool', table_name='task_instance')
- op.drop_index('ti_dag_state', table_name='task_instance')
- op.drop_table('task_instance')
- op.drop_table('slot_pool')
- op.drop_table('sla_miss')
- op.drop_table('log')
- op.drop_index('job_type_heart', table_name='job')
- op.drop_table('job')
- op.drop_table('import_error')
- op.drop_table('dag_pickle')
- op.drop_table('dag')
- op.drop_table('connection')
- op.drop_table('xcom')
+ op.drop_table("chart")
+ op.drop_table("variable")
+ op.drop_table("user")
+ op.drop_index("ti_state_lkp", table_name="task_instance")
+ op.drop_index("ti_pool", table_name="task_instance")
+ op.drop_index("ti_dag_state", table_name="task_instance")
+ op.drop_table("task_instance")
+ op.drop_table("slot_pool")
+ op.drop_table("sla_miss")
+ op.drop_table("log")
+ op.drop_index("job_type_heart", table_name="job")
+ op.drop_table("job")
+ op.drop_table("import_error")
+ op.drop_table("dag_pickle")
+ op.drop_table("dag")
+ op.drop_table("connection")
+ op.drop_table("xcom")
diff --git a/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py b/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py
index 2d13be3d2e234..9ab70672c7457 100644
--- a/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py
+++ b/airflow/migrations/versions/0002_1_5_0_create_is_encrypted.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``is_encrypted`` column in ``connection`` table
Revision ID: 1507a7289a2f
@@ -23,20 +22,21 @@
Create Date: 2015-08-18 18:57:51.927315
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
-
-from airflow.compat.sqlalchemy import inspect
+from sqlalchemy import inspect
# revision identifiers, used by Alembic.
-revision = '1507a7289a2f'
-down_revision = 'e3a246e0dc1'
+revision = "1507a7289a2f"
+down_revision = "e3a246e0dc1"
branch_labels = None
depends_on = None
-airflow_version = '1.5.0'
+airflow_version = "1.5.0"
connectionhelper = sa.Table(
- 'connection', sa.MetaData(), sa.Column('id', sa.Integer, primary_key=True), sa.Column('is_encrypted')
+ "connection", sa.MetaData(), sa.Column("id", sa.Integer, primary_key=True), sa.Column("is_encrypted")
)
@@ -49,16 +49,16 @@ def upgrade():
# this will only be true if 'connection' already exists in the db,
# but not if alembic created it in a previous migration
- if 'connection' in inspector.get_table_names():
- col_names = [c['name'] for c in inspector.get_columns('connection')]
- if 'is_encrypted' in col_names:
+ if "connection" in inspector.get_table_names():
+ col_names = [c["name"] for c in inspector.get_columns("connection")]
+ if "is_encrypted" in col_names:
return
- op.add_column('connection', sa.Column('is_encrypted', sa.Boolean, unique=False, default=False))
+ op.add_column("connection", sa.Column("is_encrypted", sa.Boolean, unique=False, default=False))
conn = op.get_bind()
conn.execute(connectionhelper.update().values(is_encrypted=False))
def downgrade():
- op.drop_column('connection', 'is_encrypted')
+ op.drop_column("connection", "is_encrypted")
diff --git a/airflow/migrations/versions/0003_1_5_0_for_compatibility.py b/airflow/migrations/versions/0003_1_5_0_for_compatibility.py
index cd45e1fa745da..99c0a5df0db17 100644
--- a/airflow/migrations/versions/0003_1_5_0_for_compatibility.py
+++ b/airflow/migrations/versions/0003_1_5_0_for_compatibility.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Maintain history for compatibility with earlier migrations
Revision ID: 13eb55f81627
@@ -23,13 +22,14 @@
Create Date: 2015-08-23 05:12:49.732174
"""
+from __future__ import annotations
# revision identifiers, used by Alembic.
-revision = '13eb55f81627'
-down_revision = '1507a7289a2f'
+revision = "13eb55f81627"
+down_revision = "1507a7289a2f"
branch_labels = None
depends_on = None
-airflow_version = '1.5.0'
+airflow_version = "1.5.0"
def upgrade():
diff --git a/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py b/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py
index 2eb793addaed5..e54d86c3771f7 100644
--- a/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py
+++ b/airflow/migrations/versions/0004_1_5_0_more_logging_into_task_isntance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``operator`` and ``queued_dttm`` to ``task_instance`` table
Revision ID: 338e90f54d61
@@ -23,22 +22,24 @@
Create Date: 2015-08-25 06:09:20.460147
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '338e90f54d61'
-down_revision = '13eb55f81627'
+revision = "338e90f54d61"
+down_revision = "13eb55f81627"
branch_labels = None
depends_on = None
-airflow_version = '1.5.0'
+airflow_version = "1.5.0"
def upgrade():
- op.add_column('task_instance', sa.Column('operator', sa.String(length=1000), nullable=True))
- op.add_column('task_instance', sa.Column('queued_dttm', sa.DateTime(), nullable=True))
+ op.add_column("task_instance", sa.Column("operator", sa.String(length=1000), nullable=True))
+ op.add_column("task_instance", sa.Column("queued_dttm", sa.DateTime(), nullable=True))
def downgrade():
- op.drop_column('task_instance', 'queued_dttm')
- op.drop_column('task_instance', 'operator')
+ op.drop_column("task_instance", "queued_dttm")
+ op.drop_column("task_instance", "operator")
diff --git a/airflow/migrations/versions/0005_1_5_2_job_id_indices.py b/airflow/migrations/versions/0005_1_5_2_job_id_indices.py
index e6ba4fd226670..e2443e676a6ba 100644
--- a/airflow/migrations/versions/0005_1_5_2_job_id_indices.py
+++ b/airflow/migrations/versions/0005_1_5_2_job_id_indices.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add indices in ``job`` table
Revision ID: 52d714495f0
@@ -23,19 +22,21 @@
Create Date: 2015-10-20 03:17:01.962542
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '52d714495f0'
-down_revision = '338e90f54d61'
+revision = "52d714495f0"
+down_revision = "338e90f54d61"
branch_labels = None
depends_on = None
-airflow_version = '1.5.2'
+airflow_version = "1.5.2"
def upgrade():
- op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False)
+ op.create_index("idx_job_state_heartbeat", "job", ["state", "latest_heartbeat"], unique=False)
def downgrade():
- op.drop_index('idx_job_state_heartbeat', table_name='job')
+ op.drop_index("idx_job_state_heartbeat", table_name="job")
diff --git a/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py b/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py
index 5bc28ad372633..1fec347e8fd17 100644
--- a/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py
+++ b/airflow/migrations/versions/0006_1_6_0_adding_extra_to_log.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Adding ``extra`` column to ``Log`` table
Revision ID: 502898887f84
@@ -23,20 +22,22 @@
Create Date: 2015-11-03 22:50:49.794097
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '502898887f84'
-down_revision = '52d714495f0'
+revision = "502898887f84"
+down_revision = "52d714495f0"
branch_labels = None
depends_on = None
-airflow_version = '1.6.0'
+airflow_version = "1.6.0"
def upgrade():
- op.add_column('log', sa.Column('extra', sa.Text(), nullable=True))
+ op.add_column("log", sa.Column("extra", sa.Text(), nullable=True))
def downgrade():
- op.drop_column('log', 'extra')
+ op.drop_column("log", "extra")
diff --git a/airflow/migrations/versions/0007_1_6_0_add_dagrun.py b/airflow/migrations/versions/0007_1_6_0_add_dagrun.py
index 66a65fd3c3d3c..93440146b4929 100644
--- a/airflow/migrations/versions/0007_1_6_0_add_dagrun.py
+++ b/airflow/migrations/versions/0007_1_6_0_add_dagrun.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``dag_run`` table
Revision ID: 1b38cef5b76e
@@ -23,6 +22,7 @@
Create Date: 2015-10-27 08:31:48.475140
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,27 +30,27 @@
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = '1b38cef5b76e'
-down_revision = '502898887f84'
+revision = "1b38cef5b76e"
+down_revision = "502898887f84"
branch_labels = None
depends_on = None
-airflow_version = '1.6.0'
+airflow_version = "1.6.0"
def upgrade():
op.create_table(
- 'dag_run',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=True),
- sa.Column('execution_date', sa.DateTime(), nullable=True),
- sa.Column('state', sa.String(length=50), nullable=True),
- sa.Column('run_id', StringID(), nullable=True),
- sa.Column('external_trigger', sa.Boolean(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('dag_id', 'execution_date'),
- sa.UniqueConstraint('dag_id', 'run_id'),
+ "dag_run",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=True),
+ sa.Column("execution_date", sa.DateTime(), nullable=True),
+ sa.Column("state", sa.String(length=50), nullable=True),
+ sa.Column("run_id", StringID(), nullable=True),
+ sa.Column("external_trigger", sa.Boolean(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("dag_id", "execution_date"),
+ sa.UniqueConstraint("dag_id", "run_id"),
)
def downgrade():
- op.drop_table('dag_run')
+ op.drop_table("dag_run")
diff --git a/airflow/migrations/versions/0008_1_6_0_task_duration.py b/airflow/migrations/versions/0008_1_6_0_task_duration.py
index 9fa217d1aa181..0a17100e25844 100644
--- a/airflow/migrations/versions/0008_1_6_0_task_duration.py
+++ b/airflow/migrations/versions/0008_1_6_0_task_duration.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Change ``task_instance.task_duration`` type to ``FLOAT``
Revision ID: 2e541a1dcfed
@@ -23,24 +22,25 @@
Create Date: 2015-10-28 20:38:41.266143
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = '2e541a1dcfed'
-down_revision = '1b38cef5b76e'
+revision = "2e541a1dcfed"
+down_revision = "1b38cef5b76e"
branch_labels = None
depends_on = None
-airflow_version = '1.6.0'
+airflow_version = "1.6.0"
def upgrade():
# use batch_alter_table to support SQLite workaround
with op.batch_alter_table("task_instance") as batch_op:
batch_op.alter_column(
- 'duration',
+ "duration",
existing_type=mysql.INTEGER(display_width=11),
type_=sa.Float(),
existing_nullable=True,
diff --git a/airflow/migrations/versions/0009_1_6_0_dagrun_config.py b/airflow/migrations/versions/0009_1_6_0_dagrun_config.py
index ae7790713797c..502598b09c6b7 100644
--- a/airflow/migrations/versions/0009_1_6_0_dagrun_config.py
+++ b/airflow/migrations/versions/0009_1_6_0_dagrun_config.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``conf`` column in ``dag_run`` table
Revision ID: 40e67319e3a9
@@ -23,20 +22,22 @@
Create Date: 2015-10-29 08:36:31.726728
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '40e67319e3a9'
-down_revision = '2e541a1dcfed'
+revision = "40e67319e3a9"
+down_revision = "2e541a1dcfed"
branch_labels = None
depends_on = None
-airflow_version = '1.6.0'
+airflow_version = "1.6.0"
def upgrade():
- op.add_column('dag_run', sa.Column('conf', sa.PickleType(), nullable=True))
+ op.add_column("dag_run", sa.Column("conf", sa.PickleType(), nullable=True))
def downgrade():
- op.drop_column('dag_run', 'conf')
+ op.drop_column("dag_run", "conf")
diff --git a/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py b/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py
index d1004ba52a8d3..5938c3e1cf0d6 100644
--- a/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py
+++ b/airflow/migrations/versions/0010_1_6_2_add_password_column_to_user.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``password`` column to ``user`` table
Revision ID: 561833c1c74b
@@ -23,20 +22,22 @@
Create Date: 2015-11-30 06:51:25.872557
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '561833c1c74b'
-down_revision = '40e67319e3a9'
+revision = "561833c1c74b"
+down_revision = "40e67319e3a9"
branch_labels = None
depends_on = None
-airflow_version = '1.6.2'
+airflow_version = "1.6.2"
def upgrade():
- op.add_column('user', sa.Column('password', sa.String(255)))
+ op.add_column("user", sa.Column("password", sa.String(255)))
def downgrade():
- op.drop_column('user', 'password')
+ op.drop_column("user", "password")
diff --git a/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py b/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py
index ae471443b04e8..e03f49a8a7024 100644
--- a/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py
+++ b/airflow/migrations/versions/0011_1_6_2_dagrun_start_end.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``start_date`` and ``end_date`` in ``dag_run`` table
Revision ID: 4446e08588
@@ -23,23 +22,24 @@
Create Date: 2015-12-10 11:26:18.439223
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '4446e08588'
-down_revision = '561833c1c74b'
+revision = "4446e08588"
+down_revision = "561833c1c74b"
branch_labels = None
depends_on = None
-airflow_version = '1.6.2'
+airflow_version = "1.6.2"
def upgrade():
- op.add_column('dag_run', sa.Column('end_date', sa.DateTime(), nullable=True))
- op.add_column('dag_run', sa.Column('start_date', sa.DateTime(), nullable=True))
+ op.add_column("dag_run", sa.Column("end_date", sa.DateTime(), nullable=True))
+ op.add_column("dag_run", sa.Column("start_date", sa.DateTime(), nullable=True))
def downgrade():
- op.drop_column('dag_run', 'start_date')
- op.drop_column('dag_run', 'end_date')
+ op.drop_column("dag_run", "start_date")
+ op.drop_column("dag_run", "end_date")
diff --git a/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py b/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py
index 0e3548257c7a8..58ff703aeed16 100644
--- a/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py
+++ b/airflow/migrations/versions/0012_1_7_0_add_notification_sent_column_to_sla_miss.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``notification_sent`` column to ``sla_miss`` table
Revision ID: bbc73705a13e
@@ -23,20 +22,22 @@
Create Date: 2016-01-14 18:05:54.871682
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'bbc73705a13e'
-down_revision = '4446e08588'
+revision = "bbc73705a13e"
+down_revision = "4446e08588"
branch_labels = None
depends_on = None
-airflow_version = '1.7.0'
+airflow_version = "1.7.0"
def upgrade():
- op.add_column('sla_miss', sa.Column('notification_sent', sa.Boolean, default=False))
+ op.add_column("sla_miss", sa.Column("notification_sent", sa.Boolean, default=False))
def downgrade():
- op.drop_column('sla_miss', 'notification_sent')
+ op.drop_column("sla_miss", "notification_sent")
diff --git a/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py b/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py
index 4e0d53e062654..f8bd23bf5f441 100644
--- a/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py
+++ b/airflow/migrations/versions/0013_1_7_0_add_a_column_to_track_the_encryption_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add a column to track the encryption state of the 'Extra' field in connection
Revision ID: bba5a7cfc896
@@ -23,21 +22,22 @@
Create Date: 2016-01-29 15:10:32.656425
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'bba5a7cfc896'
-down_revision = 'bbc73705a13e'
+revision = "bba5a7cfc896"
+down_revision = "bbc73705a13e"
branch_labels = None
depends_on = None
-airflow_version = '1.7.0'
+airflow_version = "1.7.0"
def upgrade():
- op.add_column('connection', sa.Column('is_extra_encrypted', sa.Boolean, default=False))
+ op.add_column("connection", sa.Column("is_extra_encrypted", sa.Boolean, default=False))
def downgrade():
- op.drop_column('connection', 'is_extra_encrypted')
+ op.drop_column("connection", "is_extra_encrypted")
diff --git a/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py b/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py
index e72260ec2089d..eba4999da7f8c 100644
--- a/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py
+++ b/airflow/migrations/versions/0014_1_7_0_add_is_encrypted_column_to_variable_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``is_encrypted`` column to variable table
Revision ID: 1968acfc09e3
@@ -23,20 +22,22 @@
Create Date: 2016-02-02 17:20:55.692295
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '1968acfc09e3'
-down_revision = 'bba5a7cfc896'
+revision = "1968acfc09e3"
+down_revision = "bba5a7cfc896"
branch_labels = None
depends_on = None
-airflow_version = '1.7.0'
+airflow_version = "1.7.0"
def upgrade():
- op.add_column('variable', sa.Column('is_encrypted', sa.Boolean, default=False))
+ op.add_column("variable", sa.Column("is_encrypted", sa.Boolean, default=False))
def downgrade():
- op.drop_column('variable', 'is_encrypted')
+ op.drop_column("variable", "is_encrypted")
diff --git a/airflow/migrations/versions/0015_1_7_1_rename_user_table.py b/airflow/migrations/versions/0015_1_7_1_rename_user_table.py
index 7259eb5c1821c..3b644e1453118 100644
--- a/airflow/migrations/versions/0015_1_7_1_rename_user_table.py
+++ b/airflow/migrations/versions/0015_1_7_1_rename_user_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Rename user table
Revision ID: 2e82aab8ef20
@@ -23,19 +22,21 @@
Create Date: 2016-04-02 19:28:15.211915
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '2e82aab8ef20'
-down_revision = '1968acfc09e3'
+revision = "2e82aab8ef20"
+down_revision = "1968acfc09e3"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1'
+airflow_version = "1.7.1"
def upgrade():
- op.rename_table('user', 'users')
+ op.rename_table("user", "users")
def downgrade():
- op.rename_table('users', 'user')
+ op.rename_table("users", "user")
diff --git a/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py b/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py
index eae9bfd06450d..1f10af8d156fe 100644
--- a/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py
+++ b/airflow/migrations/versions/0016_1_7_1_add_ti_state_index.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add TI state index
Revision ID: 211e584da130
@@ -23,19 +22,21 @@
Create Date: 2016-06-30 10:54:24.323588
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '211e584da130'
-down_revision = '2e82aab8ef20'
+revision = "211e584da130"
+down_revision = "2e82aab8ef20"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
- op.create_index('ti_state', 'task_instance', ['state'], unique=False)
+ op.create_index("ti_state", "task_instance", ["state"], unique=False)
def downgrade():
- op.drop_index('ti_state', table_name='task_instance')
+ op.drop_index("ti_state", table_name="task_instance")
diff --git a/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py b/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py
index 679efb70a1818..45aad1d2c403f 100644
--- a/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py
+++ b/airflow/migrations/versions/0017_1_7_1_add_task_fails_journal_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``task_fail`` table
Revision ID: 64de9cddf6c9
@@ -23,32 +22,34 @@
Create Date: 2016-08-03 14:02:59.203021
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = '64de9cddf6c9'
-down_revision = '211e584da130'
+revision = "64de9cddf6c9"
+down_revision = "211e584da130"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
op.create_table(
- 'task_fail',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('execution_date', sa.DateTime(), nullable=False),
- sa.Column('start_date', sa.DateTime(), nullable=True),
- sa.Column('end_date', sa.DateTime(), nullable=True),
- sa.Column('duration', sa.Integer(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ "task_fail",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("execution_date", sa.DateTime(), nullable=False),
+ sa.Column("start_date", sa.DateTime(), nullable=True),
+ sa.Column("end_date", sa.DateTime(), nullable=True),
+ sa.Column("duration", sa.Integer(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
def downgrade():
- op.drop_table('task_fail')
+ op.drop_table("task_fail")
diff --git a/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py b/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py
index 9cb37d2a41cbc..19726d872e395 100644
--- a/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py
+++ b/airflow/migrations/versions/0018_1_7_1_add_dag_stats_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``dag_stats`` table
Revision ID: f2ca10b85618
@@ -23,29 +22,31 @@
Create Date: 2016-07-20 15:08:28.247537
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = 'f2ca10b85618'
-down_revision = '64de9cddf6c9'
+revision = "f2ca10b85618"
+down_revision = "64de9cddf6c9"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
op.create_table(
- 'dag_stats',
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('state', sa.String(length=50), nullable=False),
- sa.Column('count', sa.Integer(), nullable=False, default=0),
- sa.Column('dirty', sa.Boolean(), nullable=False, default=False),
- sa.PrimaryKeyConstraint('dag_id', 'state'),
+ "dag_stats",
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("state", sa.String(length=50), nullable=False),
+ sa.Column("count", sa.Integer(), nullable=False, default=0),
+ sa.Column("dirty", sa.Boolean(), nullable=False, default=False),
+ sa.PrimaryKeyConstraint("dag_id", "state"),
)
def downgrade():
- op.drop_table('dag_stats')
+ op.drop_table("dag_stats")
diff --git a/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py b/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py
index 9cee531e749d5..23ec9bf1776d1 100644
--- a/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py
+++ b/airflow/migrations/versions/0019_1_7_1_add_fractional_seconds_to_mysql_tables.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add fractional seconds to MySQL tables
Revision ID: 4addfa1236f1
@@ -23,100 +22,101 @@
Create Date: 2016-09-11 13:39:18.592072
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = '4addfa1236f1'
-down_revision = 'f2ca10b85618'
+revision = "4addfa1236f1"
+down_revision = "f2ca10b85618"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
- op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag_run", column_name="execution_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="import_error", column_name="timestamp", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="job", column_name="start_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="job", column_name="end_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="log", column_name="dttm", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="log", column_name="execution_date", type_=mysql.DATETIME(fsp=6))
op.alter_column(
- table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(fsp=6), nullable=False
+ table_name="sla_miss", column_name="execution_date", type_=mysql.DATETIME(fsp=6), nullable=False
)
- op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="task_fail", column_name="execution_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="task_fail", column_name="start_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.DATETIME(fsp=6))
op.alter_column(
- table_name='task_instance',
- column_name='execution_date',
+ table_name="task_instance",
+ column_name="execution_date",
type_=mysql.DATETIME(fsp=6),
nullable=False,
)
- op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="task_instance", column_name="start_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="task_instance", column_name="end_date", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="task_instance", column_name="queued_dttm", type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME(fsp=6))
- op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.DATETIME(fsp=6))
+ op.alter_column(table_name="xcom", column_name="execution_date", type_=mysql.DATETIME(fsp=6))
def downgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
- op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mysql.DATETIME())
- op.alter_column(table_name='dag', column_name='last_pickled', type_=mysql.DATETIME())
- op.alter_column(table_name='dag', column_name='last_expired', type_=mysql.DATETIME())
+ op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=mysql.DATETIME())
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=mysql.DATETIME())
+ op.alter_column(table_name="dag", column_name="last_expired", type_=mysql.DATETIME())
- op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mysql.DATETIME())
+ op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=mysql.DATETIME())
- op.alter_column(table_name='dag_run', column_name='execution_date', type_=mysql.DATETIME())
- op.alter_column(table_name='dag_run', column_name='start_date', type_=mysql.DATETIME())
- op.alter_column(table_name='dag_run', column_name='end_date', type_=mysql.DATETIME())
+ op.alter_column(table_name="dag_run", column_name="execution_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="dag_run", column_name="start_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="dag_run", column_name="end_date", type_=mysql.DATETIME())
- op.alter_column(table_name='import_error', column_name='timestamp', type_=mysql.DATETIME())
+ op.alter_column(table_name="import_error", column_name="timestamp", type_=mysql.DATETIME())
- op.alter_column(table_name='job', column_name='start_date', type_=mysql.DATETIME())
- op.alter_column(table_name='job', column_name='end_date', type_=mysql.DATETIME())
- op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mysql.DATETIME())
+ op.alter_column(table_name="job", column_name="start_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="job", column_name="end_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mysql.DATETIME())
- op.alter_column(table_name='log', column_name='dttm', type_=mysql.DATETIME())
- op.alter_column(table_name='log', column_name='execution_date', type_=mysql.DATETIME())
+ op.alter_column(table_name="log", column_name="dttm", type_=mysql.DATETIME())
+ op.alter_column(table_name="log", column_name="execution_date", type_=mysql.DATETIME())
op.alter_column(
- table_name='sla_miss', column_name='execution_date', type_=mysql.DATETIME(), nullable=False
+ table_name="sla_miss", column_name="execution_date", type_=mysql.DATETIME(), nullable=False
)
- op.alter_column(table_name='sla_miss', column_name='timestamp', type_=mysql.DATETIME())
+ op.alter_column(table_name="sla_miss", column_name="timestamp", type_=mysql.DATETIME())
- op.alter_column(table_name='task_fail', column_name='execution_date', type_=mysql.DATETIME())
- op.alter_column(table_name='task_fail', column_name='start_date', type_=mysql.DATETIME())
- op.alter_column(table_name='task_fail', column_name='end_date', type_=mysql.DATETIME())
+ op.alter_column(table_name="task_fail", column_name="execution_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="task_fail", column_name="start_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=mysql.DATETIME())
op.alter_column(
- table_name='task_instance', column_name='execution_date', type_=mysql.DATETIME(), nullable=False
+ table_name="task_instance", column_name="execution_date", type_=mysql.DATETIME(), nullable=False
)
- op.alter_column(table_name='task_instance', column_name='start_date', type_=mysql.DATETIME())
- op.alter_column(table_name='task_instance', column_name='end_date', type_=mysql.DATETIME())
- op.alter_column(table_name='task_instance', column_name='queued_dttm', type_=mysql.DATETIME())
+ op.alter_column(table_name="task_instance", column_name="start_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="task_instance", column_name="end_date", type_=mysql.DATETIME())
+ op.alter_column(table_name="task_instance", column_name="queued_dttm", type_=mysql.DATETIME())
- op.alter_column(table_name='xcom', column_name='timestamp', type_=mysql.DATETIME())
- op.alter_column(table_name='xcom', column_name='execution_date', type_=mysql.DATETIME())
+ op.alter_column(table_name="xcom", column_name="timestamp", type_=mysql.DATETIME())
+ op.alter_column(table_name="xcom", column_name="execution_date", type_=mysql.DATETIME())
diff --git a/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py b/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py
index ea2aba8e32620..d162a5872ccc4 100644
--- a/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py
+++ b/airflow/migrations/versions/0020_1_7_1_xcom_dag_task_indices.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add indices on ``xcom`` table
Revision ID: 8504051e801b
@@ -23,22 +22,23 @@
Create Date: 2016-11-29 08:13:03.253312
"""
+from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
-revision = '8504051e801b'
-down_revision = '4addfa1236f1'
+revision = "8504051e801b"
+down_revision = "4addfa1236f1"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
"""Create Index."""
- op.create_index('idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False)
+ op.create_index("idx_xcom_dag_task_date", "xcom", ["dag_id", "task_id", "execution_date"], unique=False)
def downgrade():
"""Drop Index."""
- op.drop_index('idx_xcom_dag_task_date', table_name='xcom')
+ op.drop_index("idx_xcom_dag_task_date", table_name="xcom")
diff --git a/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py b/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py
index cea1d1010a888..0eb7e9bf730d4 100644
--- a/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py
+++ b/airflow/migrations/versions/0021_1_7_1_add_pid_field_to_taskinstance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``pid`` field to ``TaskInstance``
Revision ID: 5e7d17757c7a
@@ -23,23 +22,24 @@
Create Date: 2016-12-07 15:51:37.119478
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '5e7d17757c7a'
-down_revision = '8504051e801b'
+revision = "5e7d17757c7a"
+down_revision = "8504051e801b"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
"""Add pid column to task_instance table."""
- op.add_column('task_instance', sa.Column('pid', sa.Integer))
+ op.add_column("task_instance", sa.Column("pid", sa.Integer))
def downgrade():
"""Drop pid column from task_instance table."""
- op.drop_column('task_instance', 'pid')
+ op.drop_column("task_instance", "pid")
diff --git a/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py b/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py
index a7acdb2e7b187..bdf070b739f15 100644
--- a/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py
+++ b/airflow/migrations/versions/0022_1_7_1_add_dag_id_state_index_on_dag_run_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``dag_id``/``state`` index on ``dag_run`` table
Revision ID: 127d2bf2dfa7
@@ -23,19 +22,21 @@
Create Date: 2017-01-25 11:43:51.635667
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '127d2bf2dfa7'
-down_revision = '5e7d17757c7a'
+revision = "127d2bf2dfa7"
+down_revision = "5e7d17757c7a"
branch_labels = None
depends_on = None
-airflow_version = '1.7.1.3'
+airflow_version = "1.7.1.3"
def upgrade():
- op.create_index('dag_id_state', 'dag_run', ['dag_id', 'state'], unique=False)
+ op.create_index("dag_id_state", "dag_run", ["dag_id", "state"], unique=False)
def downgrade():
- op.drop_index('dag_id_state', table_name='dag_run')
+ op.drop_index("dag_id_state", table_name="dag_run")
diff --git a/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py
index 7685b77afd04d..c6eb6222dc190 100644
--- a/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py
+++ b/airflow/migrations/versions/0023_1_8_2_add_max_tries_column_to_task_instance.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``max_tries`` column to ``task_instance``
Revision ID: cc1e65623dc7
@@ -22,22 +21,22 @@
Create Date: 2017-06-19 16:53:12.851141
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
-from sqlalchemy import Column, Integer, String
+from sqlalchemy import Column, Integer, String, inspect
from sqlalchemy.ext.declarative import declarative_base
from airflow import settings
-from airflow.compat.sqlalchemy import inspect
from airflow.models import DagBag
# revision identifiers, used by Alembic.
-revision = 'cc1e65623dc7'
-down_revision = '127d2bf2dfa7'
+revision = "cc1e65623dc7"
+down_revision = "127d2bf2dfa7"
branch_labels = None
depends_on = None
-airflow_version = '1.8.2'
+airflow_version = "1.8.2"
Base = declarative_base()
BATCH_SIZE = 5000
@@ -56,7 +55,7 @@ class TaskInstance(Base): # type: ignore
def upgrade():
- op.add_column('task_instance', sa.Column('max_tries', sa.Integer, server_default="-1"))
+ op.add_column("task_instance", sa.Column("max_tries", sa.Integer, server_default="-1"))
# Check if table task_instance exist before data migration. This check is
# needed for database that does not create table until migration finishes.
# Checking task_instance table exists prevent the error of querying
@@ -65,7 +64,7 @@ def upgrade():
inspector = inspect(connection)
tables = inspector.get_table_names()
- if 'task_instance' in tables:
+ if "task_instance" in tables:
# Get current session
sessionmaker = sa.orm.sessionmaker()
session = sessionmaker(bind=connection)
@@ -102,7 +101,7 @@ def upgrade():
def downgrade():
engine = settings.engine
connection = op.get_bind()
- if engine.dialect.has_table(connection, 'task_instance'):
+ if engine.dialect.has_table(connection, "task_instance"):
sessionmaker = sa.orm.sessionmaker()
session = sessionmaker(bind=connection)
dagbag = DagBag(settings.DAGS_FOLDER)
@@ -124,4 +123,4 @@ def downgrade():
session.merge(ti)
session.commit()
session.commit()
- op.drop_column('task_instance', 'max_tries')
+ op.drop_column("task_instance", "max_tries")
diff --git a/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py b/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py
index 00a21bf543125..4512918b5ed7d 100644
--- a/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py
+++ b/airflow/migrations/versions/0024_1_8_2_make_xcom_value_column_a_large_binary.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Make xcom value column a large binary
Revision ID: bdaa763e6c56
@@ -23,16 +22,18 @@
Create Date: 2017-08-14 16:06:31.568971
"""
+from __future__ import annotations
+
import dill
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'bdaa763e6c56'
-down_revision = 'cc1e65623dc7'
+revision = "bdaa763e6c56"
+down_revision = "cc1e65623dc7"
branch_labels = None
depends_on = None
-airflow_version = '1.8.2'
+airflow_version = "1.8.2"
def upgrade():
@@ -40,10 +41,10 @@ def upgrade():
# type.
# use batch_alter_table to support SQLite workaround
with op.batch_alter_table("xcom") as batch_op:
- batch_op.alter_column('value', type_=sa.LargeBinary())
+ batch_op.alter_column("value", type_=sa.LargeBinary())
def downgrade():
# use batch_alter_table to support SQLite workaround
with op.batch_alter_table("xcom") as batch_op:
- batch_op.alter_column('value', type_=sa.PickleType(pickler=dill))
+ batch_op.alter_column("value", type_=sa.PickleType(pickler=dill))
diff --git a/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py b/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py
index 364e57d50360d..79f148ff17acf 100644
--- a/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py
+++ b/airflow/migrations/versions/0025_1_8_2_add_ti_job_id_index.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Create index on ``job_id`` column in ``task_instance`` table
Revision ID: 947454bf1dff
@@ -23,19 +22,21 @@
Create Date: 2017-08-15 15:12:13.845074
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '947454bf1dff'
-down_revision = 'bdaa763e6c56'
+revision = "947454bf1dff"
+down_revision = "bdaa763e6c56"
branch_labels = None
depends_on = None
-airflow_version = '1.8.2'
+airflow_version = "1.8.2"
def upgrade():
- op.create_index('ti_job_id', 'task_instance', ['job_id'], unique=False)
+ op.create_index("ti_job_id", "task_instance", ["job_id"], unique=False)
def downgrade():
- op.drop_index('ti_job_id', table_name='task_instance')
+ op.drop_index("ti_job_id", table_name="task_instance")
diff --git a/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py b/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py
index 41450aec0f9c2..711ba3662a4a9 100644
--- a/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py
+++ b/airflow/migrations/versions/0026_1_8_2_increase_text_size_for_mysql.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase text size for MySQL (not relevant for other DBs' text types)
Revision ID: d2ae31099d61
@@ -23,24 +22,26 @@
Create Date: 2017-08-18 17:07:16.686130
"""
+from __future__ import annotations
+
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = 'd2ae31099d61'
-down_revision = '947454bf1dff'
+revision = "d2ae31099d61"
+down_revision = "947454bf1dff"
branch_labels = None
depends_on = None
-airflow_version = '1.8.2'
+airflow_version = "1.8.2"
def upgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
- op.alter_column(table_name='variable', column_name='val', type_=mysql.MEDIUMTEXT)
+ op.alter_column(table_name="variable", column_name="val", type_=mysql.MEDIUMTEXT)
def downgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
- op.alter_column(table_name='variable', column_name='val', type_=mysql.TEXT)
+ op.alter_column(table_name="variable", column_name="val", type_=mysql.TEXT)
diff --git a/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py b/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py
index 8945ff06c99ca..2b36361c5b87e 100644
--- a/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py
+++ b/airflow/migrations/versions/0027_1_10_0_add_time_zone_awareness.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add time zone awareness
Revision ID: 0e2a74e0fc9f
@@ -22,6 +21,7 @@
Create Date: 2017-11-10 22:22:31.326152
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -32,7 +32,7 @@
down_revision = "d2ae31099d61"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
def upgrade():
diff --git a/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py b/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py
index c37831ac05c0d..8ed7187a508fa 100644
--- a/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py
+++ b/airflow/migrations/versions/0028_1_10_0_add_kubernetes_resource_checkpointing.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
"""Add Kubernetes resource check-pointing
Revision ID: 33ae817a1ff4
@@ -23,17 +21,18 @@
Create Date: 2017-09-11 15:26:47.598494
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
-
-from airflow.compat.sqlalchemy import inspect
+from sqlalchemy import inspect
# revision identifiers, used by Alembic.
-revision = '33ae817a1ff4'
-down_revision = 'd2ae31099d61'
+revision = "33ae817a1ff4"
+down_revision = "d2ae31099d61"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
RESOURCE_TABLE = "kube_resource_version"
diff --git a/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py b/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py
index 67bda056ecb51..9808c2c882363 100644
--- a/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py
+++ b/airflow/migrations/versions/0029_1_10_0_add_executor_config_to_task_instance.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
"""Add ``executor_config`` column to ``task_instance`` table
Revision ID: 33ae817a1ff4
@@ -23,17 +21,18 @@
Create Date: 2017-09-11 15:26:47.598494
"""
+from __future__ import annotations
import dill
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '27c6a30d7c24'
-down_revision = '33ae817a1ff4'
+revision = "27c6a30d7c24"
+down_revision = "33ae817a1ff4"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
TASK_INSTANCE_TABLE = "task_instance"
NEW_COLUMN = "executor_config"
diff --git a/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py b/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py
index 4dd1864580a9d..c90934576e8c7 100644
--- a/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py
+++ b/airflow/migrations/versions/0030_1_10_0_add_kubernetes_scheduler_uniqueness.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
"""Add kubernetes scheduler uniqueness
Revision ID: 86770d1215c0
@@ -23,15 +21,17 @@
Create Date: 2018-04-03 15:31:20.814328
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '86770d1215c0'
-down_revision = '27c6a30d7c24'
+revision = "86770d1215c0"
+down_revision = "27c6a30d7c24"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
RESOURCE_TABLE = "kube_worker_uuid"
diff --git a/airflow/migrations/versions/0031_1_10_0_merge_heads.py b/airflow/migrations/versions/0031_1_10_0_merge_heads.py
index 33691b7b04ec3..2f9f404351ae0 100644
--- a/airflow/migrations/versions/0031_1_10_0_merge_heads.py
+++ b/airflow/migrations/versions/0031_1_10_0_merge_heads.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Merge migrations Heads
Revision ID: 05f30312d566
@@ -22,13 +21,14 @@
Create Date: 2018-06-17 10:47:23.339972
"""
+from __future__ import annotations
# revision identifiers, used by Alembic.
-revision = '05f30312d566'
-down_revision = ('86770d1215c0', '0e2a74e0fc9f')
+revision = "05f30312d566"
+down_revision = ("86770d1215c0", "0e2a74e0fc9f")
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
def upgrade():
diff --git a/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py b/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py
index 765d43762e1aa..c5d0eecb9eb9a 100644
--- a/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py
+++ b/airflow/migrations/versions/0032_1_10_0_fix_mysql_not_null_constraint.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Fix MySQL not null constraint
Revision ID: f23433877c24
@@ -22,30 +21,32 @@
Create Date: 2018-06-17 10:16:31.412131
"""
+from __future__ import annotations
+
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = 'f23433877c24'
-down_revision = '05f30312d566'
+revision = "f23433877c24"
+down_revision = "05f30312d566"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
def upgrade():
conn = op.get_bind()
- if conn.dialect.name == 'mysql':
+ if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
- op.alter_column('task_fail', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
- op.alter_column('xcom', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
- op.alter_column('xcom', 'timestamp', existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
+ op.alter_column("task_fail", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
+ op.alter_column("xcom", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
+ op.alter_column("xcom", "timestamp", existing_type=mysql.TIMESTAMP(fsp=6), nullable=False)
def downgrade():
conn = op.get_bind()
- if conn.dialect.name == 'mysql':
+ if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
- op.alter_column('xcom', 'timestamp', existing_type=mysql.TIMESTAMP(fsp=6), nullable=True)
- op.alter_column('xcom', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=True)
- op.alter_column('task_fail', 'execution_date', existing_type=mysql.TIMESTAMP(fsp=6), nullable=True)
+ op.alter_column("xcom", "timestamp", existing_type=mysql.TIMESTAMP(fsp=6), nullable=True)
+ op.alter_column("xcom", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=True)
+ op.alter_column("task_fail", "execution_date", existing_type=mysql.TIMESTAMP(fsp=6), nullable=True)
diff --git a/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py b/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py
index abf289572fbda..f13891b4d4079 100644
--- a/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py
+++ b/airflow/migrations/versions/0033_1_10_0_fix_sqlite_foreign_key.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Fix Sqlite foreign key
Revision ID: 856955da8476
@@ -22,22 +21,23 @@
Create Date: 2018-06-17 15:54:53.844230
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '856955da8476'
-down_revision = 'f23433877c24'
+revision = "856955da8476"
+down_revision = "f23433877c24"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
def upgrade():
"""Fix broken foreign-key constraint for existing SQLite DBs."""
conn = op.get_bind()
- if conn.dialect.name == 'sqlite':
+ if conn.dialect.name == "sqlite":
# Fix broken foreign-key constraint for existing SQLite DBs.
#
# Re-define tables and use copy_from to avoid reflection
@@ -45,27 +45,27 @@ def upgrade():
#
# Use batch_alter_table to support SQLite workaround.
chart_table = sa.Table(
- 'chart',
+ "chart",
sa.MetaData(),
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('label', sa.String(length=200), nullable=True),
- sa.Column('conn_id', sa.String(length=250), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('chart_type', sa.String(length=100), nullable=True),
- sa.Column('sql_layout', sa.String(length=50), nullable=True),
- sa.Column('sql', sa.Text(), nullable=True),
- sa.Column('y_log_scale', sa.Boolean(), nullable=True),
- sa.Column('show_datatable', sa.Boolean(), nullable=True),
- sa.Column('show_sql', sa.Boolean(), nullable=True),
- sa.Column('height', sa.Integer(), nullable=True),
- sa.Column('default_params', sa.String(length=5000), nullable=True),
- sa.Column('x_is_date', sa.Boolean(), nullable=True),
- sa.Column('iteration_no', sa.Integer(), nullable=True),
- sa.Column('last_modified', sa.DateTime(), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("label", sa.String(length=200), nullable=True),
+ sa.Column("conn_id", sa.String(length=250), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("chart_type", sa.String(length=100), nullable=True),
+ sa.Column("sql_layout", sa.String(length=50), nullable=True),
+ sa.Column("sql", sa.Text(), nullable=True),
+ sa.Column("y_log_scale", sa.Boolean(), nullable=True),
+ sa.Column("show_datatable", sa.Boolean(), nullable=True),
+ sa.Column("show_sql", sa.Boolean(), nullable=True),
+ sa.Column("height", sa.Integer(), nullable=True),
+ sa.Column("default_params", sa.String(length=5000), nullable=True),
+ sa.Column("x_is_date", sa.Boolean(), nullable=True),
+ sa.Column("iteration_no", sa.Integer(), nullable=True),
+ sa.Column("last_modified", sa.DateTime(), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
- with op.batch_alter_table('chart', copy_from=chart_table) as batch_op:
- batch_op.create_foreign_key('chart_user_id_fkey', 'users', ['user_id'], ['id'])
+ with op.batch_alter_table("chart", copy_from=chart_table) as batch_op:
+ batch_op.create_foreign_key("chart_user_id_fkey", "users", ["user_id"], ["id"])
def downgrade():
diff --git a/airflow/migrations/versions/0034_1_10_0_index_taskfail.py b/airflow/migrations/versions/0034_1_10_0_index_taskfail.py
index 1323e77f7f040..084cb0d925ac9 100644
--- a/airflow/migrations/versions/0034_1_10_0_index_taskfail.py
+++ b/airflow/migrations/versions/0034_1_10_0_index_taskfail.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Create index on ``task_fail`` table
Revision ID: 9635ae0956e7
@@ -22,21 +21,23 @@
Create Date: 2018-06-17 21:40:01.963540
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = '9635ae0956e7'
-down_revision = '856955da8476'
+revision = "9635ae0956e7"
+down_revision = "856955da8476"
branch_labels = None
depends_on = None
-airflow_version = '1.10.0'
+airflow_version = "1.10.0"
def upgrade():
op.create_index(
- 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_fail_dag_task_date", "task_fail", ["dag_id", "task_id", "execution_date"], unique=False
)
def downgrade():
- op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')
+ op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail")
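Reviewer note: every migration touched here gains ``from __future__ import annotations``. The sketch below illustrates what that import changes, assuming Python 3.7 or newer. The function ``f`` and the source string are hypothetical and do not come from any migration; the compile flag is used only so the snippet does not depend on where the import sits in a file.

```python
import __future__

# Hypothetical module source; not taken from any migration in this diff.
SOURCE = "def f(x: int | None) -> str:\n    return str(x)\n"

# Compiling with the `annotations` feature flag has the same effect as placing
# `from __future__ import annotations` at the top of the module.
code = compile(SOURCE, "<sketch>", "exec", flags=__future__.annotations.compiler_flag)
namespace = {}
exec(code, namespace)

# Annotations are stored as strings and are not evaluated at definition time.
annotations = namespace["f"].__annotations__
print(annotations)
```

Because the annotations are never evaluated, the import should not alter what ``upgrade()`` or ``downgrade()`` do at runtime in these files.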
diff --git a/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py b/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py
index 76fcfd375b9e2..149c6ca8d52e3 100644
--- a/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py
+++ b/airflow/migrations/versions/0035_1_10_2_add_idx_log_dag.py
@@ -22,19 +22,21 @@
Create Date: 2018-08-07 06:41:41.028249
"""
+from __future__ import annotations
+
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'dd25f486b8ea'
-down_revision = '9635ae0956e7'
+revision = "dd25f486b8ea"
+down_revision = "9635ae0956e7"
branch_labels = None
depends_on = None
-airflow_version = '1.10.2'
+airflow_version = "1.10.2"
def upgrade():
- op.create_index('idx_log_dag', 'log', ['dag_id'], unique=False)
+ op.create_index("idx_log_dag", "log", ["dag_id"], unique=False)
def downgrade():
- op.drop_index('idx_log_dag', table_name='log')
+ op.drop_index("idx_log_dag", table_name="log")
diff --git a/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py b/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py
index a52663dcc9dea..f0453d0e8fb36 100644
--- a/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py
+++ b/airflow/migrations/versions/0036_1_10_2_add_index_to_taskinstance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add index to ``task_instance`` table
Revision ID: bf00311e1990
@@ -23,20 +22,21 @@
Create Date: 2018-09-12 09:53:52.007433
"""
+from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'bf00311e1990'
-down_revision = 'dd25f486b8ea'
+revision = "bf00311e1990"
+down_revision = "dd25f486b8ea"
branch_labels = None
depends_on = None
-airflow_version = '1.10.2'
+airflow_version = "1.10.2"
def upgrade():
- op.create_index('ti_dag_date', 'task_instance', ['dag_id', 'execution_date'], unique=False)
+ op.create_index("ti_dag_date", "task_instance", ["dag_id", "execution_date"], unique=False)
def downgrade():
- op.drop_index('ti_dag_date', table_name='task_instance')
+ op.drop_index("ti_dag_date", table_name="task_instance")
diff --git a/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py b/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py
index a26204fff122e..c30916725f154 100644
--- a/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py
+++ b/airflow/migrations/versions/0037_1_10_2_add_task_reschedule_table.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``task_reschedule`` table
Revision ID: 0a2a5b66e19d
@@ -22,49 +21,51 @@
Create Date: 2018-06-17 22:50:00.053620
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
from airflow.migrations.db_types import TIMESTAMP, StringID
# revision identifiers, used by Alembic.
-revision = '0a2a5b66e19d'
-down_revision = '9635ae0956e7'
+revision = "0a2a5b66e19d"
+down_revision = "9635ae0956e7"
branch_labels = None
depends_on = None
-airflow_version = '1.10.2'
+airflow_version = "1.10.2"
-TABLE_NAME = 'task_reschedule'
-INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date'
+TABLE_NAME = "task_reschedule"
+INDEX_NAME = "idx_" + TABLE_NAME + "_dag_task_date"
def upgrade():
# See 0e2a74e0fc9f_add_time_zone_awareness
timestamp = TIMESTAMP
- if op.get_bind().dialect.name == 'mssql':
+ if op.get_bind().dialect.name == "mssql":
# We need to keep this as it was for this old migration on mssql
timestamp = sa.DateTime()
op.create_table(
TABLE_NAME,
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
# use explicit server_default=None otherwise mysql implies defaults for first timestamp column
- sa.Column('execution_date', timestamp, nullable=False, server_default=None),
- sa.Column('try_number', sa.Integer(), nullable=False),
- sa.Column('start_date', timestamp, nullable=False),
- sa.Column('end_date', timestamp, nullable=False),
- sa.Column('duration', sa.Integer(), nullable=False),
- sa.Column('reschedule_date', timestamp, nullable=False),
- sa.PrimaryKeyConstraint('id'),
+ sa.Column("execution_date", timestamp, nullable=False, server_default=None),
+ sa.Column("try_number", sa.Integer(), nullable=False),
+ sa.Column("start_date", timestamp, nullable=False),
+ sa.Column("end_date", timestamp, nullable=False),
+ sa.Column("duration", sa.Integer(), nullable=False),
+ sa.Column("reschedule_date", timestamp, nullable=False),
+ sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
- ['task_id', 'dag_id', 'execution_date'],
- ['task_instance.task_id', 'task_instance.dag_id', 'task_instance.execution_date'],
- name='task_reschedule_dag_task_date_fkey',
+ ["task_id", "dag_id", "execution_date"],
+ ["task_instance.task_id", "task_instance.dag_id", "task_instance.execution_date"],
+ name="task_reschedule_dag_task_date_fkey",
),
)
- op.create_index(INDEX_NAME, TABLE_NAME, ['dag_id', 'task_id', 'execution_date'], unique=False)
+ op.create_index(INDEX_NAME, TABLE_NAME, ["dag_id", "task_id", "execution_date"], unique=False)
def downgrade():
diff --git a/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py b/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py
index a4bf9a03e1985..f3188901e46d6 100644
--- a/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py
+++ b/airflow/migrations/versions/0038_1_10_2_add_sm_dag_index.py
@@ -14,28 +14,28 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""Merge migrations Heads
+"""Merge migrations Heads.
Revision ID: 03bc53e68815
Revises: 0a2a5b66e19d, bf00311e1990
Create Date: 2018-11-24 20:21:46.605414
"""
+from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
-revision = '03bc53e68815'
-down_revision = ('0a2a5b66e19d', 'bf00311e1990')
+revision = "03bc53e68815"
+down_revision = ("0a2a5b66e19d", "bf00311e1990")
branch_labels = None
depends_on = None
-airflow_version = '1.10.2'
+airflow_version = "1.10.2"
def upgrade():
- op.create_index('sm_dag', 'sla_miss', ['dag_id'], unique=False)
+ op.create_index("sm_dag", "sla_miss", ["dag_id"], unique=False)
def downgrade():
- op.drop_index('sm_dag', table_name='sla_miss')
+ op.drop_index("sm_dag", table_name="sla_miss")
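Reviewer note: ``0038`` is a merge revision, so its ``down_revision`` is a tuple rather than a single string. The sketch below shows how the two branches rejoin, using only the revision IDs visible in this diff. The helpers ``parents`` and ``ancestors`` are hypothetical and are not part of Alembic.

```python
# Revision IDs and down_revision values copied from the migrations in this diff.
# The root's own down_revision ("f23433877c24") is left out to keep the sketch small.
DOWN_REVISIONS = {
    "856955da8476": None,
    "9635ae0956e7": "856955da8476",
    "dd25f486b8ea": "9635ae0956e7",
    "bf00311e1990": "dd25f486b8ea",
    "0a2a5b66e19d": "9635ae0956e7",
    "03bc53e68815": ("0a2a5b66e19d", "bf00311e1990"),  # merge revision
}


def parents(revision):
    """Return the direct parents of a revision as a tuple (hypothetical helper)."""
    down = DOWN_REVISIONS[revision]
    if down is None:
        return ()
    if isinstance(down, str):
        return (down,)
    return tuple(down)


def ancestors(revision):
    """Collect every revision reachable by following down_revision links."""
    seen = set()
    stack = list(parents(revision))
    while stack:
        current = stack.pop()
        if current not in seen:
            seen.add(current)
            stack.extend(parents(current))
    return seen


merge_ancestors = ancestors("03bc53e68815")
print(sorted(merge_ancestors))
```

Both ``0a2a5b66e19d`` and ``bf00311e1990`` descend from ``9635ae0956e7``, which is why a merge revision is needed to bring the history back to a single head.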
diff --git a/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py b/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py
index 11ea2938deb7d..314be0bba8c65 100644
--- a/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py
+++ b/airflow/migrations/versions/0039_1_10_2_add_superuser_field.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add superuser field
Revision ID: 41f5f12752f8
@@ -22,21 +21,22 @@
Create Date: 2018-12-04 15:50:04.456875
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '41f5f12752f8'
-down_revision = '03bc53e68815'
+revision = "41f5f12752f8"
+down_revision = "03bc53e68815"
branch_labels = None
depends_on = None
-airflow_version = '1.10.2'
+airflow_version = "1.10.2"
def upgrade():
- op.add_column('users', sa.Column('superuser', sa.Boolean(), default=False))
+ op.add_column("users", sa.Column("superuser", sa.Boolean(), default=False))
def downgrade():
- op.drop_column('users', 'superuser')
+ op.drop_column("users", "superuser")
diff --git a/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py b/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py
index a1eb00f14fb52..2c9c2d04a1bb1 100644
--- a/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py
+++ b/airflow/migrations/versions/0040_1_10_3_add_fields_to_dag.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``description`` and ``default_view`` column to ``dag`` table
Revision ID: c8ffec048a3b
@@ -23,23 +22,24 @@
Create Date: 2018-12-23 21:55:46.463634
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'c8ffec048a3b'
-down_revision = '41f5f12752f8'
+revision = "c8ffec048a3b"
+down_revision = "41f5f12752f8"
branch_labels = None
depends_on = None
-airflow_version = '1.10.3'
+airflow_version = "1.10.3"
def upgrade():
- op.add_column('dag', sa.Column('description', sa.Text(), nullable=True))
- op.add_column('dag', sa.Column('default_view', sa.String(25), nullable=True))
+ op.add_column("dag", sa.Column("description", sa.Text(), nullable=True))
+ op.add_column("dag", sa.Column("default_view", sa.String(25), nullable=True))
def downgrade():
- op.drop_column('dag', 'description')
- op.drop_column('dag', 'default_view')
+ op.drop_column("dag", "description")
+ op.drop_column("dag", "default_view")
diff --git a/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py b/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py
index 42e0a5db92882..2f741f0e65b65 100644
--- a/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py
+++ b/airflow/migrations/versions/0041_1_10_3_add_schedule_interval_to_dag.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add schedule interval to dag
Revision ID: dd4ecb8fbee3
@@ -23,21 +22,22 @@
Create Date: 2018-12-27 18:39:25.748032
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'dd4ecb8fbee3'
-down_revision = 'c8ffec048a3b'
+revision = "dd4ecb8fbee3"
+down_revision = "c8ffec048a3b"
branch_labels = None
depends_on = None
-airflow_version = '1.10.3'
+airflow_version = "1.10.3"
def upgrade():
- op.add_column('dag', sa.Column('schedule_interval', sa.Text(), nullable=True))
+ op.add_column("dag", sa.Column("schedule_interval", sa.Text(), nullable=True))
def downgrade():
- op.drop_column('dag', 'schedule_interval')
+ op.drop_column("dag", "schedule_interval")
diff --git a/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py b/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py
index 9a535cbc99c48..9c42ef2b188ff 100644
--- a/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py
+++ b/airflow/migrations/versions/0042_1_10_3_task_reschedule_fk_on_cascade_delete.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""task reschedule foreign key on cascade delete
Revision ID: 939bb1e647c8
@@ -22,35 +21,36 @@
Create Date: 2019-02-04 20:21:50.669751
"""
+from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
-revision = '939bb1e647c8'
-down_revision = 'dd4ecb8fbee3'
+revision = "939bb1e647c8"
+down_revision = "dd4ecb8fbee3"
branch_labels = None
depends_on = None
-airflow_version = '1.10.3'
+airflow_version = "1.10.3"
def upgrade():
- with op.batch_alter_table('task_reschedule') as batch_op:
- batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
+ with op.batch_alter_table("task_reschedule") as batch_op:
+ batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey")
batch_op.create_foreign_key(
- 'task_reschedule_dag_task_date_fkey',
- 'task_instance',
- ['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date'],
- ondelete='CASCADE',
+ "task_reschedule_dag_task_date_fkey",
+ "task_instance",
+ ["task_id", "dag_id", "execution_date"],
+ ["task_id", "dag_id", "execution_date"],
+ ondelete="CASCADE",
)
def downgrade():
- with op.batch_alter_table('task_reschedule') as batch_op:
- batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
+ with op.batch_alter_table("task_reschedule") as batch_op:
+ batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey")
batch_op.create_foreign_key(
- 'task_reschedule_dag_task_date_fkey',
- 'task_instance',
- ['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date'],
+ "task_reschedule_dag_task_date_fkey",
+ "task_instance",
+ ["task_id", "dag_id", "execution_date"],
+ ["task_id", "dag_id", "execution_date"],
)
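Reviewer note: the only behavioural difference between ``upgrade()`` and ``downgrade()`` in ``0042`` is ``ondelete="CASCADE"``. The sketch below shows the effect of that option using the standard-library ``sqlite3`` module. The tables are simplified stand-ins that keep only the key columns, and the row values are made up for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# SQLite only enforces foreign keys when this pragma is switched on.
conn.execute("PRAGMA foreign_keys = ON")

# Simplified stand-ins for the real tables: only the key columns are kept.
conn.execute(
    """
    CREATE TABLE task_instance (
        task_id TEXT NOT NULL,
        dag_id TEXT NOT NULL,
        execution_date TEXT NOT NULL,
        PRIMARY KEY (task_id, dag_id, execution_date)
    )
    """
)
conn.execute(
    """
    CREATE TABLE task_reschedule (
        id INTEGER PRIMARY KEY,
        task_id TEXT NOT NULL,
        dag_id TEXT NOT NULL,
        execution_date TEXT NOT NULL,
        CONSTRAINT task_reschedule_dag_task_date_fkey
            FOREIGN KEY (task_id, dag_id, execution_date)
            REFERENCES task_instance (task_id, dag_id, execution_date)
            ON DELETE CASCADE
    )
    """
)

key = ("t1", "d1", "2019-02-04")  # made-up sample row
conn.execute("INSERT INTO task_instance VALUES (?, ?, ?)", key)
conn.execute(
    "INSERT INTO task_reschedule (task_id, dag_id, execution_date) VALUES (?, ?, ?)", key
)

# Deleting the parent row removes the dependent row as well.
conn.execute("DELETE FROM task_instance WHERE task_id = ?", ("t1",))
remaining = conn.execute("SELECT COUNT(*) FROM task_reschedule").fetchone()[0]
print(remaining)
```

Without the ``ON DELETE CASCADE`` clause, the same ``DELETE`` would be rejected while a matching ``task_reschedule`` row exists.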
diff --git a/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py
index 53261af4a87e8..52069c56daba3 100644
--- a/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py
+++ b/airflow/migrations/versions/0043_1_10_4_make_taskinstance_pool_not_nullable.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Make ``TaskInstance.pool`` not nullable
Revision ID: 6e96a59344a4
@@ -23,6 +22,7 @@
Create Date: 2019-06-13 21:51:32.878437
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -33,11 +33,11 @@
from airflow.utils.sqlalchemy import UtcDateTime
# revision identifiers, used by Alembic.
-revision = '6e96a59344a4'
-down_revision = '939bb1e647c8'
+revision = "6e96a59344a4"
+down_revision = "939bb1e647c8"
branch_labels = None
depends_on = None
-airflow_version = '1.10.4'
+airflow_version = "1.10.4"
Base = declarative_base()
ID_LEN = 250
@@ -58,45 +58,45 @@ def upgrade():
"""Make TaskInstance.pool field not nullable."""
with create_session() as session:
session.query(TaskInstance).filter(TaskInstance.pool.is_(None)).update(
- {TaskInstance.pool: 'default_pool'}, synchronize_session=False
+ {TaskInstance.pool: "default_pool"}, synchronize_session=False
) # Avoid select updated rows
session.commit()
conn = op.get_bind()
if conn.dialect.name == "mssql":
- op.drop_index('ti_pool', table_name='task_instance')
+ op.drop_index("ti_pool", table_name="task_instance")
# use batch_alter_table to support SQLite workaround
- with op.batch_alter_table('task_instance') as batch_op:
+ with op.batch_alter_table("task_instance") as batch_op:
batch_op.alter_column(
- column_name='pool',
+ column_name="pool",
type_=sa.String(50),
nullable=False,
)
if conn.dialect.name == "mssql":
- op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'])
+ op.create_index("ti_pool", "task_instance", ["pool", "state", "priority_weight"])
def downgrade():
"""Make TaskInstance.pool field nullable."""
conn = op.get_bind()
if conn.dialect.name == "mssql":
- op.drop_index('ti_pool', table_name='task_instance')
+ op.drop_index("ti_pool", table_name="task_instance")
# use batch_alter_table to support SQLite workaround
- with op.batch_alter_table('task_instance') as batch_op:
+ with op.batch_alter_table("task_instance") as batch_op:
batch_op.alter_column(
- column_name='pool',
+ column_name="pool",
type_=sa.String(50),
nullable=True,
)
if conn.dialect.name == "mssql":
- op.create_index('ti_pool', 'task_instance', ['pool', 'state', 'priority_weight'])
+ op.create_index("ti_pool", "task_instance", ["pool", "state", "priority_weight"])
with create_session() as session:
- session.query(TaskInstance).filter(TaskInstance.pool == 'default_pool').update(
+ session.query(TaskInstance).filter(TaskInstance.pool == "default_pool").update(
{TaskInstance.pool: None}, synchronize_session=False
) # Avoid select updated rows
session.commit()
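Reviewer note: ``0043`` backfills ``NULL`` pools with ``default_pool`` before tightening the column to ``NOT NULL``. The sketch below shows that ordering with the standard-library ``sqlite3`` module. The table is a simplified stand-in, the sample rows are made up, and the copy-and-rename step only approximates what Alembic batch mode does on SQLite.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Simplified stand-in for task_instance with a nullable pool column.
conn.execute("CREATE TABLE task_instance (task_id TEXT PRIMARY KEY, pool TEXT)")
conn.executemany(
    "INSERT INTO task_instance VALUES (?, ?)",
    [("a", None), ("b", "etl_pool"), ("c", None)],
)

# Step 1: backfill NULLs so the NOT NULL constraint can be satisfied.
conn.execute("UPDATE task_instance SET pool = 'default_pool' WHERE pool IS NULL")

# Step 2: SQLite cannot add NOT NULL to an existing column in place, so copy
# the rows into a new table and swap it in.
conn.execute(
    "CREATE TABLE task_instance_new (task_id TEXT PRIMARY KEY, pool VARCHAR(50) NOT NULL)"
)
conn.execute("INSERT INTO task_instance_new SELECT task_id, pool FROM task_instance")
conn.execute("DROP TABLE task_instance")
conn.execute("ALTER TABLE task_instance_new RENAME TO task_instance")
conn.commit()

pools = dict(conn.execute("SELECT task_id, pool FROM task_instance"))
print(pools)
```

Running the steps in the opposite order would fail at the copy, because rows with a ``NULL`` pool cannot be inserted into the ``NOT NULL`` column.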
diff --git a/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py b/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py
index 38179f2a7033a..2feb636de5e07 100644
--- a/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py
+++ b/airflow/migrations/versions/0044_1_10_7_add_serialized_dag_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``serialized_dag`` table
Revision ID: d38e04c12aa2
@@ -23,6 +22,8 @@
Create Date: 2019-08-01 14:39:35.616417
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mysql
@@ -30,11 +31,11 @@
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = 'd38e04c12aa2'
-down_revision = '6e96a59344a4'
+revision = "d38e04c12aa2"
+down_revision = "6e96a59344a4"
branch_labels = None
depends_on = None
-airflow_version = '1.10.7'
+airflow_version = "1.10.7"
def upgrade():
@@ -51,15 +52,15 @@ def upgrade():
json_type = sa.Text
op.create_table(
- 'serialized_dag',
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('fileloc', sa.String(length=2000), nullable=False),
- sa.Column('fileloc_hash', sa.Integer(), nullable=False),
- sa.Column('data', json_type(), nullable=False),
- sa.Column('last_updated', sa.DateTime(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id'),
+ "serialized_dag",
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("fileloc", sa.String(length=2000), nullable=False),
+ sa.Column("fileloc_hash", sa.Integer(), nullable=False),
+ sa.Column("data", json_type(), nullable=False),
+ sa.Column("last_updated", sa.DateTime(), nullable=False),
+ sa.PrimaryKeyConstraint("dag_id"),
)
- op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash'])
+ op.create_index("idx_fileloc_hash", "serialized_dag", ["fileloc_hash"])
if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
@@ -93,4 +94,4 @@ def upgrade():
def downgrade():
"""Downgrade version."""
- op.drop_table('serialized_dag')
+ op.drop_table("serialized_dag")
diff --git a/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py b/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py
index f879450369ab8..8c230df3830da 100644
--- a/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py
+++ b/airflow/migrations/versions/0045_1_10_7_add_root_dag_id_to_dag.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``root_dag_id`` to ``DAG``
Revision ID: b3b105409875
@@ -23,6 +22,7 @@
Create Date: 2019-09-28 23:20:01.744775
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,20 +30,20 @@
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = 'b3b105409875'
-down_revision = 'd38e04c12aa2'
+revision = "b3b105409875"
+down_revision = "d38e04c12aa2"
branch_labels = None
depends_on = None
-airflow_version = '1.10.7'
+airflow_version = "1.10.7"
def upgrade():
"""Apply Add ``root_dag_id`` to ``DAG``"""
- op.add_column('dag', sa.Column('root_dag_id', StringID(), nullable=True))
- op.create_index('idx_root_dag_id', 'dag', ['root_dag_id'], unique=False)
+ op.add_column("dag", sa.Column("root_dag_id", StringID(), nullable=True))
+ op.create_index("idx_root_dag_id", "dag", ["root_dag_id"], unique=False)
def downgrade():
"""Unapply Add ``root_dag_id`` to ``DAG``"""
- op.drop_index('idx_root_dag_id', table_name='dag')
- op.drop_column('dag', 'root_dag_id')
+ op.drop_index("idx_root_dag_id", table_name="dag")
+ op.drop_column("dag", "root_dag_id")
diff --git a/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py b/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py
index 49e3e9b84a7e0..3b2970aae7139 100644
--- a/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py
+++ b/airflow/migrations/versions/0046_1_10_5_change_datetime_to_datetime2_6_on_mssql_.py
@@ -15,14 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""change datetime to datetime2(6) on MSSQL tables
+"""change datetime to datetime2(6) on MSSQL tables.
Revision ID: 74effc47d867
Revises: 6e96a59344a4
Create Date: 2019-08-01 15:19:57.585620
"""
+from __future__ import annotations
from collections import defaultdict
@@ -30,15 +30,15 @@
from sqlalchemy.dialects import mssql
# revision identifiers, used by Alembic.
-revision = '74effc47d867'
-down_revision = '6e96a59344a4'
+revision = "74effc47d867"
+down_revision = "6e96a59344a4"
branch_labels = None
depends_on = None
-airflow_version = '1.10.5'
+airflow_version = "1.10.5"
def upgrade():
- """Change datetime to datetime2(6) when using MSSQL as backend"""
+ """Change datetime to datetime2(6) when using MSSQL as backend."""
conn = op.get_bind()
if conn.dialect.name == "mssql":
result = conn.execute(
@@ -50,106 +50,106 @@ def upgrade():
if mssql_version in ("2000", "2005"):
return
- with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
- task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date')
- task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
+ with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op:
+ task_reschedule_batch_op.drop_index("idx_task_reschedule_dag_task_date")
+ task_reschedule_batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey")
task_reschedule_batch_op.alter_column(
column_name="execution_date",
type_=mssql.DATETIME2(precision=6),
nullable=False,
)
task_reschedule_batch_op.alter_column(
- column_name='start_date', type_=mssql.DATETIME2(precision=6)
+ column_name="start_date", type_=mssql.DATETIME2(precision=6)
)
- task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6))
+ task_reschedule_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME2(precision=6))
task_reschedule_batch_op.alter_column(
- column_name='reschedule_date', type_=mssql.DATETIME2(precision=6)
+ column_name="reschedule_date", type_=mssql.DATETIME2(precision=6)
)
- with op.batch_alter_table('task_instance') as task_instance_batch_op:
- task_instance_batch_op.drop_index('ti_state_lkp')
- task_instance_batch_op.drop_index('ti_dag_date')
+ with op.batch_alter_table("task_instance") as task_instance_batch_op:
+ task_instance_batch_op.drop_index("ti_state_lkp")
+ task_instance_batch_op.drop_index("ti_dag_date")
modify_execution_date_with_constraint(
- conn, task_instance_batch_op, 'task_instance', mssql.DATETIME2(precision=6), False
+ conn, task_instance_batch_op, "task_instance", mssql.DATETIME2(precision=6), False
)
- task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6))
- task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6))
- task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME2(precision=6))
+ task_instance_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME2(precision=6))
+ task_instance_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME2(precision=6))
+ task_instance_batch_op.alter_column(column_name="queued_dttm", type_=mssql.DATETIME2(precision=6))
task_instance_batch_op.create_index(
- 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "ti_state_lkp", ["dag_id", "task_id", "execution_date"], unique=False
)
- task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False)
+ task_instance_batch_op.create_index("ti_dag_date", ["dag_id", "execution_date"], unique=False)
- with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
+ with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op:
task_reschedule_batch_op.create_foreign_key(
- 'task_reschedule_dag_task_date_fkey',
- 'task_instance',
- ['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date'],
- ondelete='CASCADE',
+ "task_reschedule_dag_task_date_fkey",
+ "task_instance",
+ ["task_id", "dag_id", "execution_date"],
+ ["task_id", "dag_id", "execution_date"],
+ ondelete="CASCADE",
)
task_reschedule_batch_op.create_index(
- 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_reschedule_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False
)
- with op.batch_alter_table('dag_run') as dag_run_batch_op:
+ with op.batch_alter_table("dag_run") as dag_run_batch_op:
modify_execution_date_with_constraint(
- conn, dag_run_batch_op, 'dag_run', mssql.DATETIME2(precision=6), None
+ conn, dag_run_batch_op, "dag_run", mssql.DATETIME2(precision=6), None
)
- dag_run_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME2(precision=6))
- dag_run_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME2(precision=6))
+ dag_run_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME2(precision=6))
+ dag_run_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME2(precision=6))
op.alter_column(table_name="log", column_name="execution_date", type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='log', column_name='dttm', type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="log", column_name="dttm", type_=mssql.DATETIME2(precision=6))
- with op.batch_alter_table('sla_miss') as sla_miss_batch_op:
+ with op.batch_alter_table("sla_miss") as sla_miss_batch_op:
modify_execution_date_with_constraint(
- conn, sla_miss_batch_op, 'sla_miss', mssql.DATETIME2(precision=6), False
+ conn, sla_miss_batch_op, "sla_miss", mssql.DATETIME2(precision=6), False
)
- sla_miss_batch_op.alter_column(column_name='timestamp', type_=mssql.DATETIME2(precision=6))
+ sla_miss_batch_op.alter_column(column_name="timestamp", type_=mssql.DATETIME2(precision=6))
- op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')
+ op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail")
op.alter_column(
table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME2(precision=6)
)
- op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="task_fail", column_name="start_date", type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=mssql.DATETIME2(precision=6))
op.create_index(
- 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_fail_dag_task_date", "task_fail", ["dag_id", "task_id", "execution_date"], unique=False
)
- op.drop_index('idx_xcom_dag_task_date', table_name='xcom')
+ op.drop_index("idx_xcom_dag_task_date", table_name="xcom")
op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="xcom", column_name="timestamp", type_=mssql.DATETIME2(precision=6))
op.create_index(
- 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_xcom_dag_task_date", "xcom", ["dag_id", "task_id", "execution_date"], unique=False
)
op.alter_column(
- table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME2(precision=6)
+ table_name="dag", column_name="last_scheduler_run", type_=mssql.DATETIME2(precision=6)
)
- op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='dag', column_name='last_expired', type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="dag", column_name="last_expired", type_=mssql.DATETIME2(precision=6))
op.alter_column(
- table_name='dag_pickle', column_name='created_dttm', type_=mssql.DATETIME2(precision=6)
+ table_name="dag_pickle", column_name="created_dttm", type_=mssql.DATETIME2(precision=6)
)
op.alter_column(
- table_name='import_error', column_name='timestamp', type_=mssql.DATETIME2(precision=6)
+ table_name="import_error", column_name="timestamp", type_=mssql.DATETIME2(precision=6)
)
- op.drop_index('job_type_heart', table_name='job')
- op.drop_index('idx_job_state_heartbeat', table_name='job')
- op.alter_column(table_name='job', column_name='start_date', type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='job', column_name='end_date', type_=mssql.DATETIME2(precision=6))
- op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mssql.DATETIME2(precision=6))
- op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False)
- op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False)
+ op.drop_index("job_type_heart", table_name="job")
+ op.drop_index("idx_job_state_heartbeat", table_name="job")
+ op.alter_column(table_name="job", column_name="start_date", type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="job", column_name="end_date", type_=mssql.DATETIME2(precision=6))
+ op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mssql.DATETIME2(precision=6))
+ op.create_index("idx_job_state_heartbeat", "job", ["state", "latest_heartbeat"], unique=False)
+ op.create_index("job_type_heart", "job", ["job_type", "latest_heartbeat"], unique=False)
def downgrade():
- """Change datetime2(6) back to datetime"""
+ """Change datetime2(6) back to datetime."""
conn = op.get_bind()
if conn.dialect.name == "mssql":
result = conn.execute(
@@ -161,88 +161,89 @@ def downgrade():
if mssql_version in ("2000", "2005"):
return
- with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
- task_reschedule_batch_op.drop_index('idx_task_reschedule_dag_task_date')
- task_reschedule_batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', type_='foreignkey')
+ with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op:
+ task_reschedule_batch_op.drop_index("idx_task_reschedule_dag_task_date")
+ task_reschedule_batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", type_="foreignkey")
task_reschedule_batch_op.alter_column(
column_name="execution_date", type_=mssql.DATETIME, nullable=False
)
- task_reschedule_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME)
- task_reschedule_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME)
- task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=mssql.DATETIME)
+ task_reschedule_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME)
+ task_reschedule_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME)
+ task_reschedule_batch_op.alter_column(column_name="reschedule_date", type_=mssql.DATETIME)
- with op.batch_alter_table('task_instance') as task_instance_batch_op:
- task_instance_batch_op.drop_index('ti_state_lkp')
- task_instance_batch_op.drop_index('ti_dag_date')
+ with op.batch_alter_table("task_instance") as task_instance_batch_op:
+ task_instance_batch_op.drop_index("ti_state_lkp")
+ task_instance_batch_op.drop_index("ti_dag_date")
modify_execution_date_with_constraint(
- conn, task_instance_batch_op, 'task_instance', mssql.DATETIME, False
+ conn, task_instance_batch_op, "task_instance", mssql.DATETIME, False
)
- task_instance_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME)
- task_instance_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME)
- task_instance_batch_op.alter_column(column_name='queued_dttm', type_=mssql.DATETIME)
+ task_instance_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME)
+ task_instance_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME)
+ task_instance_batch_op.alter_column(column_name="queued_dttm", type_=mssql.DATETIME)
task_instance_batch_op.create_index(
- 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "ti_state_lkp", ["dag_id", "task_id", "execution_date"], unique=False
)
- task_instance_batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False)
+ task_instance_batch_op.create_index("ti_dag_date", ["dag_id", "execution_date"], unique=False)
- with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
+ with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op:
task_reschedule_batch_op.create_foreign_key(
- 'task_reschedule_dag_task_date_fkey',
- 'task_instance',
- ['task_id', 'dag_id', 'execution_date'],
- ['task_id', 'dag_id', 'execution_date'],
- ondelete='CASCADE',
+ "task_reschedule_dag_task_date_fkey",
+ "task_instance",
+ ["task_id", "dag_id", "execution_date"],
+ ["task_id", "dag_id", "execution_date"],
+ ondelete="CASCADE",
)
task_reschedule_batch_op.create_index(
- 'idx_task_reschedule_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_reschedule_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False
)
- with op.batch_alter_table('dag_run') as dag_run_batch_op:
- modify_execution_date_with_constraint(conn, dag_run_batch_op, 'dag_run', mssql.DATETIME, None)
- dag_run_batch_op.alter_column(column_name='start_date', type_=mssql.DATETIME)
- dag_run_batch_op.alter_column(column_name='end_date', type_=mssql.DATETIME)
+ with op.batch_alter_table("dag_run") as dag_run_batch_op:
+ modify_execution_date_with_constraint(conn, dag_run_batch_op, "dag_run", mssql.DATETIME, None)
+ dag_run_batch_op.alter_column(column_name="start_date", type_=mssql.DATETIME)
+ dag_run_batch_op.alter_column(column_name="end_date", type_=mssql.DATETIME)
op.alter_column(table_name="log", column_name="execution_date", type_=mssql.DATETIME)
- op.alter_column(table_name='log', column_name='dttm', type_=mssql.DATETIME)
+ op.alter_column(table_name="log", column_name="dttm", type_=mssql.DATETIME)
- with op.batch_alter_table('sla_miss') as sla_miss_batch_op:
- modify_execution_date_with_constraint(conn, sla_miss_batch_op, 'sla_miss', mssql.DATETIME, False)
- sla_miss_batch_op.alter_column(column_name='timestamp', type_=mssql.DATETIME)
+ with op.batch_alter_table("sla_miss") as sla_miss_batch_op:
+ modify_execution_date_with_constraint(conn, sla_miss_batch_op, "sla_miss", mssql.DATETIME, False)
+ sla_miss_batch_op.alter_column(column_name="timestamp", type_=mssql.DATETIME)
- op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')
+ op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail")
op.alter_column(table_name="task_fail", column_name="execution_date", type_=mssql.DATETIME)
- op.alter_column(table_name='task_fail', column_name='start_date', type_=mssql.DATETIME)
- op.alter_column(table_name='task_fail', column_name='end_date', type_=mssql.DATETIME)
+ op.alter_column(table_name="task_fail", column_name="start_date", type_=mssql.DATETIME)
+ op.alter_column(table_name="task_fail", column_name="end_date", type_=mssql.DATETIME)
op.create_index(
- 'idx_task_fail_dag_task_date', 'task_fail', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_fail_dag_task_date", "task_fail", ["dag_id", "task_id", "execution_date"], unique=False
)
- op.drop_index('idx_xcom_dag_task_date', table_name='xcom')
+ op.drop_index("idx_xcom_dag_task_date", table_name="xcom")
op.alter_column(table_name="xcom", column_name="execution_date", type_=mssql.DATETIME)
- op.alter_column(table_name='xcom', column_name='timestamp', type_=mssql.DATETIME)
+ op.alter_column(table_name="xcom", column_name="timestamp", type_=mssql.DATETIME)
op.create_index(
- 'idx_xcom_dag_task_date', 'xcom', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_xcom_dag_task_date", "xcom", ["dag_id", "task_id", "execution_date"], unique=False
)
- op.alter_column(table_name='dag', column_name='last_scheduler_run', type_=mssql.DATETIME)
- op.alter_column(table_name='dag', column_name='last_pickled', type_=mssql.DATETIME)
- op.alter_column(table_name='dag', column_name='last_expired', type_=mssql.DATETIME)
+ op.alter_column(table_name="dag", column_name="last_scheduler_run", type_=mssql.DATETIME)
+ op.alter_column(table_name="dag", column_name="last_pickled", type_=mssql.DATETIME)
+ op.alter_column(table_name="dag", column_name="last_expired", type_=mssql.DATETIME)
- op.alter_column(table_name='dag_pickle', column_name='created_dttm', type_=mssql.DATETIME)
+ op.alter_column(table_name="dag_pickle", column_name="created_dttm", type_=mssql.DATETIME)
- op.alter_column(table_name='import_error', column_name='timestamp', type_=mssql.DATETIME)
+ op.alter_column(table_name="import_error", column_name="timestamp", type_=mssql.DATETIME)
- op.drop_index('job_type_heart', table_name='job')
- op.drop_index('idx_job_state_heartbeat', table_name='job')
- op.alter_column(table_name='job', column_name='start_date', type_=mssql.DATETIME)
- op.alter_column(table_name='job', column_name='end_date', type_=mssql.DATETIME)
- op.alter_column(table_name='job', column_name='latest_heartbeat', type_=mssql.DATETIME)
- op.create_index('idx_job_state_heartbeat', 'job', ['state', 'latest_heartbeat'], unique=False)
- op.create_index('job_type_heart', 'job', ['job_type', 'latest_heartbeat'], unique=False)
+ op.drop_index("job_type_heart", table_name="job")
+ op.drop_index("idx_job_state_heartbeat", table_name="job")
+ op.alter_column(table_name="job", column_name="start_date", type_=mssql.DATETIME)
+ op.alter_column(table_name="job", column_name="end_date", type_=mssql.DATETIME)
+ op.alter_column(table_name="job", column_name="latest_heartbeat", type_=mssql.DATETIME)
+ op.create_index("idx_job_state_heartbeat", "job", ["state", "latest_heartbeat"], unique=False)
+ op.create_index("job_type_heart", "job", ["job_type", "latest_heartbeat"], unique=False)
-def get_table_constraints(conn, table_name):
- """
+def get_table_constraints(conn, table_name) -> dict[tuple[str, str], list[str]]:
+ """Return primary and unique constraint along with column name.
+
This function returns primary and unique constraints
along with column names. Some tables like task_instance
are missing a primary key constraint name and the name is
@@ -252,7 +253,6 @@ def get_table_constraints(conn, table_name):
:param conn: sql connection object
:param table_name: table name
:return: a dictionary of ((constraint name, constraint type), column name) of table
- :rtype: defaultdict(list)
"""
query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
@@ -268,49 +268,47 @@ def get_table_constraints(conn, table_name):
def reorder_columns(columns):
- """
- Reorder the columns for creating constraint, preserve primary key ordering
+ """Reorder the columns for creating constraint.
+ Preserve primary key ordering
``['task_id', 'dag_id', 'execution_date']``
:param columns: columns retrieved from DB related to constraint
:return: ordered column
"""
ordered_columns = []
- for column in ['task_id', 'dag_id', 'execution_date']:
+ for column in ["task_id", "dag_id", "execution_date"]:
if column in columns:
ordered_columns.append(column)
for column in columns:
- if column not in ['task_id', 'dag_id', 'execution_date']:
+ if column not in ["task_id", "dag_id", "execution_date"]:
ordered_columns.append(column)
return ordered_columns
def drop_constraint(operator, constraint_dict):
- """
- Drop a primary key or unique constraint
+ """Drop a primary key or unique constraint.
:param operator: batch_alter_table for the table
:param constraint_dict: a dictionary of ((constraint name, constraint type), column name) of table
"""
for constraint, columns in constraint_dict.items():
- if 'execution_date' in columns:
+ if "execution_date" in columns:
if constraint[1].lower().startswith("primary"):
- operator.drop_constraint(constraint[0], type_='primary')
+ operator.drop_constraint(constraint[0], type_="primary")
elif constraint[1].lower().startswith("unique"):
- operator.drop_constraint(constraint[0], type_='unique')
+ operator.drop_constraint(constraint[0], type_="unique")
def create_constraint(operator, constraint_dict):
- """
- Create a primary key or unique constraint
+ """Create a primary key or unique constraint.
:param operator: batch_alter_table for the table
:param constraint_dict: a dictionary of ((constraint name, constraint type), column name) of table
"""
for constraint, columns in constraint_dict.items():
- if 'execution_date' in columns:
+ if "execution_date" in columns:
if constraint[1].lower().startswith("primary"):
operator.create_primary_key(constraint_name=constraint[0], columns=reorder_columns(columns))
elif constraint[1].lower().startswith("unique"):
@@ -319,8 +317,8 @@ def create_constraint(operator, constraint_dict):
)
-def modify_execution_date_with_constraint(conn, batch_operator, table_name, type_, nullable):
- """
+def modify_execution_date_with_constraint(conn, batch_operator, table_name, type_, nullable) -> None:
+ """Change type of column execution_date.
Helper function changes type of column execution_date by
dropping and recreating any primary/unique constraint associated with
the column
@@ -331,7 +329,6 @@ def modify_execution_date_with_constraint(conn, batch_operator, table_name, type
:param type_: DB column type
:param nullable: nullable (boolean)
:return: a dictionary of ((constraint name, constraint type), column name) of table
- :rtype: defaultdict(list)
"""
constraint_dict = get_table_constraints(conn, table_name)
drop_constraint(batch_operator, constraint_dict)
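
Note on the helpers in the file above: the constraint handling relies on two small pieces of pure logic, grouping constraint rows into a ``(name, type) -> columns`` dictionary and reordering columns so that ``task_id``, ``dag_id`` and ``execution_date`` come first. The following is a minimal standalone sketch of that logic, not the migration's actual code. The ``group_constraints`` helper, the ``PRIORITY`` constant and the sample rows are hypothetical and exist only for illustration; the real migration reads its rows from ``INFORMATION_SCHEMA`` and its grouping code is not shown in this diff.

```python
from collections import defaultdict

# Column order that the migration's reorder_columns docstring says to preserve.
PRIORITY = ["task_id", "dag_id", "execution_date"]


def reorder_columns(columns):
    """Return the priority columns first (in fixed order), then the rest."""
    ordered = [column for column in PRIORITY if column in columns]
    ordered += [column for column in columns if column not in PRIORITY]
    return ordered


def group_constraints(rows):
    """Group (constraint_name, constraint_type, column_name) rows by constraint.

    Hypothetical helper: assumes the query yields one row per constrained column.
    """
    grouped = defaultdict(list)
    for name, constraint_type, column in rows:
        grouped[(name, constraint_type)].append(column)
    return grouped


# Made-up sample rows standing in for the database query result.
rows = [
    ("pk_ti", "PRIMARY KEY", "execution_date"),
    ("pk_ti", "PRIMARY KEY", "dag_id"),
    ("pk_ti", "PRIMARY KEY", "task_id"),
    ("uq_x", "UNIQUE", "name"),
]
constraints = group_constraints(rows)
pk_columns = reorder_columns(constraints[("pk_ti", "PRIMARY KEY")])
print(pk_columns)
```

Under these assumptions the primary key columns come back in the fixed ``task_id``, ``dag_id``, ``execution_date`` order regardless of the order in which the database reports them, which is what keeps the recreated constraint compatible with the foreign keys that reference it.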
diff --git a/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py b/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py
index 50dbc43caea89..806495981a175 100644
--- a/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py
+++ b/airflow/migrations/versions/0047_1_10_4_increase_queue_name_size_limit.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase queue name size limit
Revision ID: 004c1210f153
@@ -23,16 +22,17 @@
Create Date: 2019-06-07 07:46:04.262275
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '004c1210f153'
-down_revision = '939bb1e647c8'
+revision = "004c1210f153"
+down_revision = "939bb1e647c8"
branch_labels = None
depends_on = None
-airflow_version = '1.10.4'
+airflow_version = "1.10.4"
def upgrade():
@@ -41,12 +41,12 @@ def upgrade():
by broker backends that might use unusually large queue names.
"""
# use batch_alter_table to support SQLite workaround
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.alter_column('queue', type_=sa.String(256))
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.alter_column("queue", type_=sa.String(256))
def downgrade():
"""Revert column size from 256 to 50 characters, might result in data loss."""
# use batch_alter_table to support SQLite workaround
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.alter_column('queue', type_=sa.String(50))
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.alter_column("queue", type_=sa.String(50))
diff --git a/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py b/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py
index 454482e71877a..b837c10487f1a 100644
--- a/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py
+++ b/airflow/migrations/versions/0048_1_10_3_remove_dag_stat_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Remove ``dag_stat`` table
Revision ID: a56c9515abdc
@@ -23,16 +22,17 @@
Create Date: 2018-12-27 10:27:59.715872
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'a56c9515abdc'
-down_revision = 'c8ffec048a3b'
+revision = "a56c9515abdc"
+down_revision = "c8ffec048a3b"
branch_labels = None
depends_on = None
-airflow_version = '1.10.3'
+airflow_version = "1.10.3"
def upgrade():
@@ -43,10 +43,10 @@ def upgrade():
def downgrade():
"""Create dag_stats table"""
op.create_table(
- 'dag_stats',
- sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.Column('state', sa.String(length=50), nullable=False),
- sa.Column('count', sa.Integer(), nullable=False, default=0),
- sa.Column('dirty', sa.Boolean(), nullable=False, default=False),
- sa.PrimaryKeyConstraint('dag_id', 'state'),
+ "dag_stats",
+ sa.Column("dag_id", sa.String(length=250), nullable=False),
+ sa.Column("state", sa.String(length=50), nullable=False),
+ sa.Column("count", sa.Integer(), nullable=False, default=0),
+ sa.Column("dirty", sa.Boolean(), nullable=False, default=False),
+ sa.PrimaryKeyConstraint("dag_id", "state"),
)
diff --git a/airflow/migrations/versions/0049_1_10_7_merge_heads.py b/airflow/migrations/versions/0049_1_10_7_merge_heads.py
index 1f589e993fe7a..c600a17ab2d05 100644
--- a/airflow/migrations/versions/0049_1_10_7_merge_heads.py
+++ b/airflow/migrations/versions/0049_1_10_7_merge_heads.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Straighten out the migrations
Revision ID: 08364691d074
@@ -23,13 +22,14 @@
Create Date: 2019-11-19 22:05:11.752222
"""
+from __future__ import annotations
# revision identifiers, used by Alembic.
-revision = '08364691d074'
-down_revision = ('a56c9515abdc', '004c1210f153', '74effc47d867', 'b3b105409875')
+revision = "08364691d074"
+down_revision = ("a56c9515abdc", "004c1210f153", "74effc47d867", "b3b105409875")
branch_labels = None
depends_on = None
-airflow_version = '1.10.7'
+airflow_version = "1.10.7"
def upgrade():
diff --git a/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py b/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py
index 955e1a668623c..c855ce7750411 100644
--- a/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py
+++ b/airflow/migrations/versions/0050_1_10_7_increase_length_for_connection_password.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase length for connection password
Revision ID: fe461863935f
@@ -23,23 +22,24 @@
Create Date: 2019-12-08 09:47:09.033009
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'fe461863935f'
-down_revision = '08364691d074'
+revision = "fe461863935f"
+down_revision = "08364691d074"
branch_labels = None
depends_on = None
-airflow_version = '1.10.7'
+airflow_version = "1.10.7"
def upgrade():
"""Apply Increase length for connection password"""
- with op.batch_alter_table('connection', schema=None) as batch_op:
+ with op.batch_alter_table("connection", schema=None) as batch_op:
batch_op.alter_column(
- 'password',
+ "password",
existing_type=sa.VARCHAR(length=500),
type_=sa.String(length=5000),
existing_nullable=True,
@@ -48,9 +48,9 @@ def upgrade():
def downgrade():
"""Unapply Increase length for connection password"""
- with op.batch_alter_table('connection', schema=None) as batch_op:
+ with op.batch_alter_table("connection", schema=None) as batch_op:
batch_op.alter_column(
- 'password',
+ "password",
existing_type=sa.String(length=5000),
type_=sa.VARCHAR(length=500),
existing_nullable=True,
diff --git a/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py b/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py
index a7ae10860e8aa..7d1c03f7ba1d0 100644
--- a/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py
+++ b/airflow/migrations/versions/0051_1_10_8_add_dagtags_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``DagTags`` table
Revision ID: 7939bcff74ba
@@ -23,6 +22,7 @@
Create Date: 2020-01-07 19:39:01.247442
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,27 +30,27 @@
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = '7939bcff74ba'
-down_revision = 'fe461863935f'
+revision = "7939bcff74ba"
+down_revision = "fe461863935f"
branch_labels = None
depends_on = None
-airflow_version = '1.10.8'
+airflow_version = "1.10.8"
def upgrade():
"""Apply Add ``DagTags`` table"""
op.create_table(
- 'dag_tag',
- sa.Column('name', sa.String(length=100), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
+ "dag_tag",
+ sa.Column("name", sa.String(length=100), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
sa.ForeignKeyConstraint(
- ['dag_id'],
- ['dag.dag_id'],
+ ["dag_id"],
+ ["dag.dag_id"],
),
- sa.PrimaryKeyConstraint('name', 'dag_id'),
+ sa.PrimaryKeyConstraint("name", "dag_id"),
)
def downgrade():
"""Unapply Add ``DagTags`` table"""
- op.drop_table('dag_tag')
+ op.drop_table("dag_tag")
diff --git a/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py b/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py
index 26bff68ba2141..ac209a87c46a6 100644
--- a/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py
+++ b/airflow/migrations/versions/0052_1_10_10_add_pool_slots_field_to_task_instance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``pool_slots`` field to ``task_instance``
Revision ID: a4c2fd67d16b
@@ -23,21 +22,22 @@
Create Date: 2020-01-14 03:35:01.161519
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'a4c2fd67d16b'
-down_revision = '7939bcff74ba'
+revision = "a4c2fd67d16b"
+down_revision = "7939bcff74ba"
branch_labels = None
depends_on = None
-airflow_version = '1.10.10'
+airflow_version = "1.10.10"
def upgrade():
- op.add_column('task_instance', sa.Column('pool_slots', sa.Integer, default=1))
+ op.add_column("task_instance", sa.Column("pool_slots", sa.Integer, default=1))
def downgrade():
- op.drop_column('task_instance', 'pool_slots')
+ op.drop_column("task_instance", "pool_slots")
diff --git a/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py b/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py
index 2027dd4745fb7..9bffb96b6db48 100644
--- a/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py
+++ b/airflow/migrations/versions/0053_1_10_10_add_rendered_task_instance_fields_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``RenderedTaskInstanceFields`` table
Revision ID: 852ae6c715af
@@ -23,6 +22,7 @@
Create Date: 2020-03-10 22:19:18.034961
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,13 +30,13 @@
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = '852ae6c715af'
-down_revision = 'a4c2fd67d16b'
+revision = "852ae6c715af"
+down_revision = "a4c2fd67d16b"
branch_labels = None
depends_on = None
-airflow_version = '1.10.10'
+airflow_version = "1.10.10"
-TABLE_NAME = 'rendered_task_instance_fields'
+TABLE_NAME = "rendered_task_instance_fields"
def upgrade():
@@ -54,11 +54,11 @@ def upgrade():
op.create_table(
TABLE_NAME,
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('execution_date', sa.TIMESTAMP(timezone=True), nullable=False),
- sa.Column('rendered_fields', json_type(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("execution_date", sa.TIMESTAMP(timezone=True), nullable=False),
+ sa.Column("rendered_fields", json_type(), nullable=False),
+ sa.PrimaryKeyConstraint("dag_id", "task_id", "execution_date"),
)
diff --git a/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py b/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py
index 523a1a0f8c0e5..e0628d142d9bb 100644
--- a/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py
+++ b/airflow/migrations/versions/0054_1_10_10_add_dag_code_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``dag_code`` table
Revision ID: 952da73b5eff
@@ -23,6 +22,7 @@
Create Date: 2020-03-12 12:39:01.797462
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,11 +30,11 @@
from airflow.models.dagcode import DagCode
# revision identifiers, used by Alembic.
-revision = '952da73b5eff'
-down_revision = '852ae6c715af'
+revision = "952da73b5eff"
+down_revision = "852ae6c715af"
branch_labels = None
depends_on = None
-airflow_version = '1.10.10'
+airflow_version = "1.10.10"
def upgrade():
@@ -44,7 +44,7 @@ def upgrade():
Base = declarative_base()
class SerializedDagModel(Base):
- __tablename__ = 'serialized_dag'
+ __tablename__ = "serialized_dag"
# There are other columns here, but these are the only ones we need for the SELECT/UPDATE we are doing
dag_id = sa.Column(sa.String(250), primary_key=True)
@@ -53,23 +53,23 @@ class SerializedDagModel(Base):
"""Apply add source code table"""
op.create_table(
- 'dag_code',
- sa.Column('fileloc_hash', sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False),
- sa.Column('fileloc', sa.String(length=2000), nullable=False),
- sa.Column('source_code', sa.UnicodeText(), nullable=False),
- sa.Column('last_updated', sa.TIMESTAMP(timezone=True), nullable=False),
+ "dag_code",
+ sa.Column("fileloc_hash", sa.BigInteger(), nullable=False, primary_key=True, autoincrement=False),
+ sa.Column("fileloc", sa.String(length=2000), nullable=False),
+ sa.Column("source_code", sa.UnicodeText(), nullable=False),
+ sa.Column("last_updated", sa.TIMESTAMP(timezone=True), nullable=False),
)
conn = op.get_bind()
- if conn.dialect.name != 'sqlite':
+ if conn.dialect.name != "sqlite":
if conn.dialect.name == "mssql":
- op.drop_index('idx_fileloc_hash', 'serialized_dag')
+ op.drop_index("idx_fileloc_hash", "serialized_dag")
op.alter_column(
- table_name='serialized_dag', column_name='fileloc_hash', type_=sa.BigInteger(), nullable=False
+ table_name="serialized_dag", column_name="fileloc_hash", type_=sa.BigInteger(), nullable=False
)
if conn.dialect.name == "mssql":
- op.create_index('idx_fileloc_hash', 'serialized_dag', ['fileloc_hash'])
+ op.create_index("idx_fileloc_hash", "serialized_dag", ["fileloc_hash"])
sessionmaker = sa.orm.sessionmaker()
session = sessionmaker(bind=conn)
@@ -82,4 +82,4 @@ class SerializedDagModel(Base):
def downgrade():
"""Unapply add source code table"""
- op.drop_table('dag_code')
+ op.drop_table("dag_code")
diff --git a/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py b/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py
index 873ba148d03d6..d9e13b8ec6d6a 100644
--- a/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py
+++ b/airflow/migrations/versions/0055_1_10_11_add_precision_to_execution_date_in_mysql.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add Precision to ``execution_date`` in ``RenderedTaskInstanceFields`` table
Revision ID: a66efa278eea
@@ -23,19 +22,20 @@
Create Date: 2020-06-16 21:44:02.883132
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = 'a66efa278eea'
-down_revision = '952da73b5eff'
+revision = "a66efa278eea"
+down_revision = "952da73b5eff"
branch_labels = None
depends_on = None
-airflow_version = '1.10.11'
+airflow_version = "1.10.11"
-TABLE_NAME = 'rendered_task_instance_fields'
-COLUMN_NAME = 'execution_date'
+TABLE_NAME = "rendered_task_instance_fields"
+COLUMN_NAME = "execution_date"
def upgrade():
diff --git a/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py b/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py
index 5113e2f28cec0..bbed486d096d8 100644
--- a/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py
+++ b/airflow/migrations/versions/0056_1_10_12_add_dag_hash_column_to_serialized_dag_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``dag_hash`` Column to ``serialized_dag`` table
Revision ID: da3f683c3a5a
@@ -23,26 +22,27 @@
Create Date: 2020-08-07 20:52:09.178296
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'da3f683c3a5a'
-down_revision = 'a66efa278eea'
+revision = "da3f683c3a5a"
+down_revision = "a66efa278eea"
branch_labels = None
depends_on = None
-airflow_version = '1.10.12'
+airflow_version = "1.10.12"
def upgrade():
"""Apply Add ``dag_hash`` Column to ``serialized_dag`` table"""
op.add_column(
- 'serialized_dag',
- sa.Column('dag_hash', sa.String(32), nullable=False, server_default='Hash not calculated yet'),
+ "serialized_dag",
+ sa.Column("dag_hash", sa.String(32), nullable=False, server_default="Hash not calculated yet"),
)
def downgrade():
"""Unapply Add ``dag_hash`` Column to ``serialized_dag`` table"""
- op.drop_column('serialized_dag', 'dag_hash')
+ op.drop_column("serialized_dag", "dag_hash")
diff --git a/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py b/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py
index bd3fe44e9ee87..56d0a8981528a 100644
--- a/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py
+++ b/airflow/migrations/versions/0057_1_10_13_add_fab_tables.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Create FAB Tables
Revision ID: 92c57b58940d
@@ -23,18 +22,18 @@
Create Date: 2020-11-13 19:27:10.161814
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
-
-from airflow.compat.sqlalchemy import inspect
+from sqlalchemy import inspect
# revision identifiers, used by Alembic.
-revision = '92c57b58940d'
-down_revision = 'da3f683c3a5a'
+revision = "92c57b58940d"
+down_revision = "da3f683c3a5a"
branch_labels = None
depends_on = None
-airflow_version = '1.10.13'
+airflow_version = "1.10.13"
def upgrade():
@@ -44,110 +43,110 @@ def upgrade():
tables = inspector.get_table_names()
if "ab_permission" not in tables:
op.create_table(
- 'ab_permission',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('name', sa.String(length=100), nullable=False),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('name'),
+ "ab_permission",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("name", sa.String(length=100), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("name"),
)
if "ab_view_menu" not in tables:
op.create_table(
- 'ab_view_menu',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('name', sa.String(length=100), nullable=False),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('name'),
+ "ab_view_menu",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("name", sa.String(length=100), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("name"),
)
if "ab_role" not in tables:
op.create_table(
- 'ab_role',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('name', sa.String(length=64), nullable=False),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('name'),
+ "ab_role",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("name", sa.String(length=64), nullable=False),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("name"),
)
if "ab_permission_view" not in tables:
op.create_table(
- 'ab_permission_view',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('permission_id', sa.Integer(), nullable=True),
- sa.Column('view_menu_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['permission_id'], ['ab_permission.id']),
- sa.ForeignKeyConstraint(['view_menu_id'], ['ab_view_menu.id']),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('permission_id', 'view_menu_id'),
+ "ab_permission_view",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("permission_id", sa.Integer(), nullable=True),
+ sa.Column("view_menu_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(["permission_id"], ["ab_permission.id"]),
+ sa.ForeignKeyConstraint(["view_menu_id"], ["ab_view_menu.id"]),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("permission_id", "view_menu_id"),
)
if "ab_permission_view_role" not in tables:
op.create_table(
- 'ab_permission_view_role',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('permission_view_id', sa.Integer(), nullable=True),
- sa.Column('role_id', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['permission_view_id'], ['ab_permission_view.id']),
- sa.ForeignKeyConstraint(['role_id'], ['ab_role.id']),
- sa.PrimaryKeyConstraint('id'),
+ "ab_permission_view_role",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("permission_view_id", sa.Integer(), nullable=True),
+ sa.Column("role_id", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(["permission_view_id"], ["ab_permission_view.id"]),
+ sa.ForeignKeyConstraint(["role_id"], ["ab_role.id"]),
+ sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("permission_view_id", "role_id"),
)
if "ab_user" not in tables:
op.create_table(
- 'ab_user',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('first_name', sa.String(length=64), nullable=False),
- sa.Column('last_name', sa.String(length=64), nullable=False),
- sa.Column('username', sa.String(length=64), nullable=False),
- sa.Column('password', sa.String(length=256), nullable=True),
- sa.Column('active', sa.Boolean(), nullable=True),
- sa.Column('email', sa.String(length=64), nullable=False),
- sa.Column('last_login', sa.DateTime(), nullable=True),
- sa.Column('login_count', sa.Integer(), nullable=True),
- sa.Column('fail_login_count', sa.Integer(), nullable=True),
- sa.Column('created_on', sa.DateTime(), nullable=True),
- sa.Column('changed_on', sa.DateTime(), nullable=True),
- sa.Column('created_by_fk', sa.Integer(), nullable=True),
- sa.Column('changed_by_fk', sa.Integer(), nullable=True),
- sa.ForeignKeyConstraint(['changed_by_fk'], ['ab_user.id']),
- sa.ForeignKeyConstraint(['created_by_fk'], ['ab_user.id']),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('email'),
- sa.UniqueConstraint('username'),
+ "ab_user",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("first_name", sa.String(length=64), nullable=False),
+ sa.Column("last_name", sa.String(length=64), nullable=False),
+ sa.Column("username", sa.String(length=64), nullable=False),
+ sa.Column("password", sa.String(length=256), nullable=True),
+ sa.Column("active", sa.Boolean(), nullable=True),
+ sa.Column("email", sa.String(length=64), nullable=False),
+ sa.Column("last_login", sa.DateTime(), nullable=True),
+ sa.Column("login_count", sa.Integer(), nullable=True),
+ sa.Column("fail_login_count", sa.Integer(), nullable=True),
+ sa.Column("created_on", sa.DateTime(), nullable=True),
+ sa.Column("changed_on", sa.DateTime(), nullable=True),
+ sa.Column("created_by_fk", sa.Integer(), nullable=True),
+ sa.Column("changed_by_fk", sa.Integer(), nullable=True),
+ sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
+ sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("email"),
+ sa.UniqueConstraint("username"),
)
if "ab_user_role" not in tables:
op.create_table(
- 'ab_user_role',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('role_id', sa.Integer(), nullable=True),
+ "ab_user_role",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("role_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
- ['role_id'],
- ['ab_role.id'],
+ ["role_id"],
+ ["ab_role.id"],
),
sa.ForeignKeyConstraint(
- ['user_id'],
- ['ab_user.id'],
+ ["user_id"],
+ ["ab_user.id"],
),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('user_id', 'role_id'),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("user_id", "role_id"),
)
if "ab_register_user" not in tables:
op.create_table(
- 'ab_register_user',
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('first_name', sa.String(length=64), nullable=False),
- sa.Column('last_name', sa.String(length=64), nullable=False),
- sa.Column('username', sa.String(length=64), nullable=False),
- sa.Column('password', sa.String(length=256), nullable=True),
- sa.Column('email', sa.String(length=64), nullable=False),
- sa.Column('registration_date', sa.DateTime(), nullable=True),
- sa.Column('registration_hash', sa.String(length=256), nullable=True),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('username'),
+ "ab_register_user",
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("first_name", sa.String(length=64), nullable=False),
+ sa.Column("last_name", sa.String(length=64), nullable=False),
+ sa.Column("username", sa.String(length=64), nullable=False),
+ sa.Column("password", sa.String(length=256), nullable=True),
+ sa.Column("email", sa.String(length=64), nullable=False),
+ sa.Column("registration_date", sa.DateTime(), nullable=True),
+ sa.Column("registration_hash", sa.String(length=256), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("username"),
)
@@ -172,7 +171,7 @@ def downgrade():
indexes = inspector.get_foreign_keys(table)
for index in indexes:
if conn.dialect.name != "sqlite":
- op.drop_constraint(index.get('name'), table, type_='foreignkey')
+ op.drop_constraint(index.get("name"), table, type_="foreignkey")
for table in fab_tables:
if table in tables:
diff --git a/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py b/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py
index 4378c8bd0c084..811562c286877 100644
--- a/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py
+++ b/airflow/migrations/versions/0058_1_10_13_increase_length_of_fab_ab_view_menu_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase length of ``Flask-AppBuilder`` ``ab_view_menu.name`` column
Revision ID: 03afc6b6f902
@@ -23,19 +22,20 @@
Create Date: 2020-11-13 22:21:41.619565
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
+from sqlalchemy import inspect
-from airflow.compat.sqlalchemy import inspect
from airflow.migrations.db_types import StringID
# revision identifiers, used by Alembic.
-revision = '03afc6b6f902'
-down_revision = '92c57b58940d'
+revision = "03afc6b6f902"
+down_revision = "92c57b58940d"
branch_labels = None
depends_on = None
-airflow_version = '1.10.13'
+airflow_version = "1.10.13"
def upgrade():
@@ -62,8 +62,8 @@ def upgrade():
op.execute("PRAGMA foreign_keys=on")
else:
op.alter_column(
- table_name='ab_view_menu',
- column_name='name',
+ table_name="ab_view_menu",
+ column_name="name",
type_=StringID(length=250),
nullable=False,
)
@@ -92,5 +92,5 @@ def downgrade():
op.execute("PRAGMA foreign_keys=on")
else:
op.alter_column(
- table_name='ab_view_menu', column_name='name', type_=sa.String(length=100), nullable=False
+ table_name="ab_view_menu", column_name="name", type_=sa.String(length=100), nullable=False
)
diff --git a/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py b/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py
index deb2778661f68..a427133ae394c 100644
--- a/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py
+++ b/airflow/migrations/versions/0059_2_0_0_drop_user_and_chart.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Drop ``user`` and ``chart`` table
Revision ID: cf5dc11e79ad
@@ -22,18 +21,19 @@
Create Date: 2019-01-24 15:30:35.834740
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
+from sqlalchemy import inspect
from sqlalchemy.dialects import mysql
-from airflow.compat.sqlalchemy import inspect
-
# revision identifiers, used by Alembic.
-revision = 'cf5dc11e79ad'
-down_revision = '03afc6b6f902'
+revision = "cf5dc11e79ad"
+down_revision = "03afc6b6f902"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
@@ -47,11 +47,11 @@ def upgrade():
inspector = inspect(conn)
tables = inspector.get_table_names()
- if 'known_event' in tables:
+ if "known_event" in tables:
for fkey in inspector.get_foreign_keys(table_name="known_event", referred_table="users"):
- if fkey['name']:
- with op.batch_alter_table(table_name='known_event') as bop:
- bop.drop_constraint(fkey['name'], type_="foreignkey")
+ if fkey["name"]:
+ with op.batch_alter_table(table_name="known_event") as bop:
+ bop.drop_constraint(fkey["name"], type_="foreignkey")
if "chart" in tables:
op.drop_table(
@@ -66,48 +66,48 @@ def downgrade():
conn = op.get_bind()
op.create_table(
- 'users',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('username', sa.String(length=250), nullable=True),
- sa.Column('email', sa.String(length=500), nullable=True),
- sa.Column('password', sa.String(255)),
- sa.Column('superuser', sa.Boolean(), default=False),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('username'),
+ "users",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("username", sa.String(length=250), nullable=True),
+ sa.Column("email", sa.String(length=500), nullable=True),
+ sa.Column("password", sa.String(255)),
+ sa.Column("superuser", sa.Boolean(), default=False),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("username"),
)
op.create_table(
- 'chart',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('label', sa.String(length=200), nullable=True),
- sa.Column('conn_id', sa.String(length=250), nullable=False),
- sa.Column('user_id', sa.Integer(), nullable=True),
- sa.Column('chart_type', sa.String(length=100), nullable=True),
- sa.Column('sql_layout', sa.String(length=50), nullable=True),
- sa.Column('sql', sa.Text(), nullable=True),
- sa.Column('y_log_scale', sa.Boolean(), nullable=True),
- sa.Column('show_datatable', sa.Boolean(), nullable=True),
- sa.Column('show_sql', sa.Boolean(), nullable=True),
- sa.Column('height', sa.Integer(), nullable=True),
- sa.Column('default_params', sa.String(length=5000), nullable=True),
- sa.Column('x_is_date', sa.Boolean(), nullable=True),
- sa.Column('iteration_no', sa.Integer(), nullable=True),
- sa.Column('last_modified', sa.DateTime(), nullable=True),
+ "chart",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("label", sa.String(length=200), nullable=True),
+ sa.Column("conn_id", sa.String(length=250), nullable=False),
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("chart_type", sa.String(length=100), nullable=True),
+ sa.Column("sql_layout", sa.String(length=50), nullable=True),
+ sa.Column("sql", sa.Text(), nullable=True),
+ sa.Column("y_log_scale", sa.Boolean(), nullable=True),
+ sa.Column("show_datatable", sa.Boolean(), nullable=True),
+ sa.Column("show_sql", sa.Boolean(), nullable=True),
+ sa.Column("height", sa.Integer(), nullable=True),
+ sa.Column("default_params", sa.String(length=5000), nullable=True),
+ sa.Column("x_is_date", sa.Boolean(), nullable=True),
+ sa.Column("iteration_no", sa.Integer(), nullable=True),
+ sa.Column("last_modified", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
- ['user_id'],
- ['users.id'],
+ ["user_id"],
+ ["users.id"],
),
- sa.PrimaryKeyConstraint('id'),
+ sa.PrimaryKeyConstraint("id"),
)
- if conn.dialect.name == 'mysql':
+ if conn.dialect.name == "mysql":
conn.execute("SET time_zone = '+00:00'")
- op.alter_column(table_name='chart', column_name='last_modified', type_=mysql.TIMESTAMP(fsp=6))
+ op.alter_column(table_name="chart", column_name="last_modified", type_=mysql.TIMESTAMP(fsp=6))
else:
- if conn.dialect.name in ('sqlite', 'mssql'):
+ if conn.dialect.name in ("sqlite", "mssql"):
return
- if conn.dialect.name == 'postgresql':
+ if conn.dialect.name == "postgresql":
conn.execute("set timezone=UTC")
- op.alter_column(table_name='chart', column_name='last_modified', type_=sa.TIMESTAMP(timezone=True))
+ op.alter_column(table_name="chart", column_name="last_modified", type_=sa.TIMESTAMP(timezone=True))
diff --git a/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py b/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py
index a588af5c53917..64ccc753a5bf0 100644
--- a/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py
+++ b/airflow/migrations/versions/0060_2_0_0_remove_id_column_from_xcom.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Remove id column from xcom
Revision ID: bbf4a7ad0465
@@ -23,23 +22,22 @@
Create Date: 2019-10-29 13:53:09.445943
"""
+from __future__ import annotations
from collections import defaultdict
from alembic import op
-from sqlalchemy import Column, Integer
-
-from airflow.compat.sqlalchemy import inspect
+from sqlalchemy import Column, Integer, inspect
# revision identifiers, used by Alembic.
-revision = 'bbf4a7ad0465'
-down_revision = 'cf5dc11e79ad'
+revision = "bbf4a7ad0465"
+down_revision = "cf5dc11e79ad"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
-def get_table_constraints(conn, table_name):
+def get_table_constraints(conn, table_name) -> dict[tuple[str, str], list[str]]:
"""
This function returns primary and unique constraints
along with column name. Some tables like `task_instance`
@@ -50,7 +48,6 @@ def get_table_constraints(conn, table_name):
:param conn: sql connection object
:param table_name: table name
:return: a dictionary of ((constraint name, constraint type), column name) of table
- :rtype: defaultdict(list)
"""
query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
@@ -75,9 +72,9 @@ def drop_column_constraints(operator, column_name, constraint_dict):
for constraint, columns in constraint_dict.items():
if column_name in columns:
if constraint[1].lower().startswith("primary"):
- operator.drop_constraint(constraint[0], type_='primary')
+ operator.drop_constraint(constraint[0], type_="primary")
elif constraint[1].lower().startswith("unique"):
- operator.drop_constraint(constraint[0], type_='unique')
+ operator.drop_constraint(constraint[0], type_="unique")
def create_constraints(operator, column_name, constraint_dict):
@@ -100,25 +97,25 @@ def upgrade():
conn = op.get_bind()
inspector = inspect(conn)
- with op.batch_alter_table('xcom') as bop:
- xcom_columns = [col.get('name') for col in inspector.get_columns("xcom")]
+ with op.batch_alter_table("xcom") as bop:
+ xcom_columns = [col.get("name") for col in inspector.get_columns("xcom")]
if "id" in xcom_columns:
- if conn.dialect.name == 'mssql':
+ if conn.dialect.name == "mssql":
constraint_dict = get_table_constraints(conn, "xcom")
- drop_column_constraints(bop, 'id', constraint_dict)
- bop.drop_column('id')
- bop.drop_index('idx_xcom_dag_task_date')
+ drop_column_constraints(bop, "id", constraint_dict)
+ bop.drop_column("id")
+ bop.drop_index("idx_xcom_dag_task_date")
# mssql doesn't allow primary keys with nullable columns
- if conn.dialect.name != 'mssql':
- bop.create_primary_key('pk_xcom', ['dag_id', 'task_id', 'key', 'execution_date'])
+ if conn.dialect.name != "mssql":
+ bop.create_primary_key("pk_xcom", ["dag_id", "task_id", "key", "execution_date"])
def downgrade():
"""Unapply Remove id column from xcom"""
conn = op.get_bind()
- with op.batch_alter_table('xcom') as bop:
- if conn.dialect.name != 'mssql':
- bop.drop_constraint('pk_xcom', type_='primary')
- bop.add_column(Column('id', Integer, nullable=False))
- bop.create_primary_key('id', ['id'])
- bop.create_index('idx_xcom_dag_task_date', ['dag_id', 'task_id', 'key', 'execution_date'])
+ with op.batch_alter_table("xcom") as bop:
+ if conn.dialect.name != "mssql":
+ bop.drop_constraint("pk_xcom", type_="primary")
+ bop.add_column(Column("id", Integer, nullable=False))
+ bop.create_primary_key("id", ["id"])
+ bop.create_index("idx_xcom_dag_task_date", ["dag_id", "task_id", "key", "execution_date"])
diff --git a/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py b/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py
index 1174c9be6bf90..62256e2d6ab30 100644
--- a/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py
+++ b/airflow/migrations/versions/0061_2_0_0_increase_length_of_pool_name.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase length of pool name
Revision ID: b25a55525161
@@ -23,6 +22,7 @@
Create Date: 2020-03-09 08:48:14.534700
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,21 +30,21 @@
from airflow.models.base import COLLATION_ARGS
# revision identifiers, used by Alembic.
-revision = 'b25a55525161'
-down_revision = 'bbf4a7ad0465'
+revision = "b25a55525161"
+down_revision = "bbf4a7ad0465"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
"""Increase column length of pool name from 50 to 256 characters"""
# use batch_alter_table to support SQLite workaround
- with op.batch_alter_table('slot_pool', table_args=sa.UniqueConstraint('pool')) as batch_op:
- batch_op.alter_column('pool', type_=sa.String(256, **COLLATION_ARGS))
+ with op.batch_alter_table("slot_pool", table_args=sa.UniqueConstraint("pool")) as batch_op:
+ batch_op.alter_column("pool", type_=sa.String(256, **COLLATION_ARGS))
def downgrade():
"""Revert Increased length of pool name from 256 to 50 characters"""
- with op.batch_alter_table('slot_pool', table_args=sa.UniqueConstraint('pool')) as batch_op:
- batch_op.alter_column('pool', type_=sa.String(50))
+ with op.batch_alter_table("slot_pool", table_args=sa.UniqueConstraint("pool")) as batch_op:
+ batch_op.alter_column("pool", type_=sa.String(50))
diff --git a/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py b/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py
index ce3d090742dd9..69da52c3850ff 100644
--- a/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py
+++ b/airflow/migrations/versions/0062_2_0_0_add_dagrun_run_type.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
Add ``run_type`` column in ``dag_run`` table
@@ -24,13 +23,13 @@
Create Date: 2020-04-08 13:35:25.671327
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
-from sqlalchemy import Column, Integer, String
+from sqlalchemy import Column, Integer, String, inspect
from sqlalchemy.ext.declarative import declarative_base
-from airflow.compat.sqlalchemy import inspect
from airflow.utils.types import DagRunType
# revision identifiers, used by Alembic.
@@ -38,7 +37,7 @@
down_revision = "b25a55525161"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
Base = declarative_base()
@@ -59,7 +58,7 @@ def upgrade():
conn = op.get_bind()
inspector = inspect(conn)
- dag_run_columns = [col.get('name') for col in inspector.get_columns("dag_run")]
+ dag_run_columns = [col.get("name") for col in inspector.get_columns("dag_run")]
if "run_type" not in dag_run_columns:
diff --git a/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py b/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py
index 1df5bfbfe324c..411a5785980e9 100644
--- a/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py
+++ b/airflow/migrations/versions/0063_2_0_0_set_conn_type_as_non_nullable.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Set ``conn_type`` as non-nullable
Revision ID: 8f966b9c467a
@@ -23,6 +22,7 @@
Create Date: 2020-06-08 22:36:34.534121
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -33,7 +33,7 @@
down_revision = "3c20cacc0044"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
diff --git a/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py b/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py
index e3172ddf34172..094935b141190 100644
--- a/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py
+++ b/airflow/migrations/versions/0064_2_0_0_add_unique_constraint_to_conn_id.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add unique constraint to ``conn_id``
Revision ID: 8d48763f6d53
@@ -23,6 +22,7 @@
Create Date: 2020-05-03 16:55:01.834231
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,17 +30,17 @@
from airflow.models.base import COLLATION_ARGS
# revision identifiers, used by Alembic.
-revision = '8d48763f6d53'
-down_revision = '8f966b9c467a'
+revision = "8d48763f6d53"
+down_revision = "8f966b9c467a"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
"""Apply Add unique constraint to ``conn_id`` and set it as non-nullable"""
try:
- with op.batch_alter_table('connection') as batch_op:
+ with op.batch_alter_table("connection") as batch_op:
batch_op.alter_column("conn_id", nullable=False, existing_type=sa.String(250, **COLLATION_ARGS))
batch_op.create_unique_constraint(constraint_name="unique_conn_id", columns=["conn_id"])
@@ -50,7 +50,7 @@ def upgrade():
def downgrade():
"""Unapply Add unique constraint to ``conn_id`` and set it as non-nullable"""
- with op.batch_alter_table('connection') as batch_op:
+ with op.batch_alter_table("connection") as batch_op:
batch_op.drop_constraint(constraint_name="unique_conn_id", type_="unique")
batch_op.alter_column("conn_id", nullable=True, existing_type=sa.String(250))
diff --git a/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py b/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py
index d98be012b7b52..93279c3dea092 100644
--- a/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py
+++ b/airflow/migrations/versions/0065_2_0_0_update_schema_for_smart_sensor.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``sensor_instance`` table
Revision ID: e38be357a868
@@ -23,19 +22,20 @@
Create Date: 2019-06-07 04:03:17.003939
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
-from sqlalchemy import func
+from sqlalchemy import func, inspect
-from airflow.compat.sqlalchemy import inspect
from airflow.migrations.db_types import TIMESTAMP, StringID
# revision identifiers, used by Alembic.
-revision = 'e38be357a868'
-down_revision = '8d48763f6d53'
+revision = "e38be357a868"
+down_revision = "8d48763f6d53"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
@@ -43,38 +43,38 @@ def upgrade():
conn = op.get_bind()
inspector = inspect(conn)
tables = inspector.get_table_names()
- if 'sensor_instance' in tables:
+ if "sensor_instance" in tables:
return
op.create_table(
- 'sensor_instance',
- sa.Column('id', sa.Integer(), nullable=False),
- sa.Column('task_id', StringID(), nullable=False),
- sa.Column('dag_id', StringID(), nullable=False),
- sa.Column('execution_date', TIMESTAMP, nullable=False),
- sa.Column('state', sa.String(length=20), nullable=True),
- sa.Column('try_number', sa.Integer(), nullable=True),
- sa.Column('start_date', TIMESTAMP, nullable=True),
- sa.Column('operator', sa.String(length=1000), nullable=False),
- sa.Column('op_classpath', sa.String(length=1000), nullable=False),
- sa.Column('hashcode', sa.BigInteger(), nullable=False),
- sa.Column('shardcode', sa.Integer(), nullable=False),
- sa.Column('poke_context', sa.Text(), nullable=False),
- sa.Column('execution_context', sa.Text(), nullable=True),
- sa.Column('created_at', TIMESTAMP, default=func.now(), nullable=False),
- sa.Column('updated_at', TIMESTAMP, default=func.now(), nullable=False),
- sa.PrimaryKeyConstraint('id'),
+ "sensor_instance",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("execution_date", TIMESTAMP, nullable=False),
+ sa.Column("state", sa.String(length=20), nullable=True),
+ sa.Column("try_number", sa.Integer(), nullable=True),
+ sa.Column("start_date", TIMESTAMP, nullable=True),
+ sa.Column("operator", sa.String(length=1000), nullable=False),
+ sa.Column("op_classpath", sa.String(length=1000), nullable=False),
+ sa.Column("hashcode", sa.BigInteger(), nullable=False),
+ sa.Column("shardcode", sa.Integer(), nullable=False),
+ sa.Column("poke_context", sa.Text(), nullable=False),
+ sa.Column("execution_context", sa.Text(), nullable=True),
+ sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.PrimaryKeyConstraint("id"),
)
- op.create_index('ti_primary_key', 'sensor_instance', ['dag_id', 'task_id', 'execution_date'], unique=True)
- op.create_index('si_hashcode', 'sensor_instance', ['hashcode'], unique=False)
- op.create_index('si_shardcode', 'sensor_instance', ['shardcode'], unique=False)
- op.create_index('si_state_shard', 'sensor_instance', ['state', 'shardcode'], unique=False)
- op.create_index('si_updated_at', 'sensor_instance', ['updated_at'], unique=False)
+ op.create_index("ti_primary_key", "sensor_instance", ["dag_id", "task_id", "execution_date"], unique=True)
+ op.create_index("si_hashcode", "sensor_instance", ["hashcode"], unique=False)
+ op.create_index("si_shardcode", "sensor_instance", ["shardcode"], unique=False)
+ op.create_index("si_state_shard", "sensor_instance", ["state", "shardcode"], unique=False)
+ op.create_index("si_updated_at", "sensor_instance", ["updated_at"], unique=False)
def downgrade():
conn = op.get_bind()
inspector = inspect(conn)
tables = inspector.get_table_names()
- if 'sensor_instance' in tables:
- op.drop_table('sensor_instance')
+ if "sensor_instance" in tables:
+ op.drop_table("sensor_instance")
diff --git a/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py b/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py
index 90247d02ef1b5..011fef7a3aa9e 100644
--- a/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py
+++ b/airflow/migrations/versions/0066_2_0_0_add_queued_by_job_id_to_ti.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add queued by Job ID to TI
Revision ID: b247b1e3d1ed
@@ -23,25 +22,26 @@
Create Date: 2020-09-04 11:53:00.978882
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'b247b1e3d1ed'
-down_revision = 'e38be357a868'
+revision = "b247b1e3d1ed"
+down_revision = "e38be357a868"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
"""Apply Add queued by Job ID to TI"""
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.add_column(sa.Column('queued_by_job_id', sa.Integer(), nullable=True))
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.add_column(sa.Column("queued_by_job_id", sa.Integer(), nullable=True))
def downgrade():
"""Unapply Add queued by Job ID to TI"""
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.drop_column('queued_by_job_id')
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.drop_column("queued_by_job_id")
diff --git a/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py b/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py
index a9b6030caff8f..aec294aef7f11 100644
--- a/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py
+++ b/airflow/migrations/versions/0067_2_0_0_add_external_executor_id_to_ti.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add external executor ID to TI
Revision ID: e1a11ece99cc
@@ -23,25 +22,26 @@
Create Date: 2020-09-12 08:23:45.698865
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'e1a11ece99cc'
-down_revision = 'b247b1e3d1ed'
+revision = "e1a11ece99cc"
+down_revision = "b247b1e3d1ed"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
"""Apply Add external executor ID to TI"""
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
- batch_op.add_column(sa.Column('external_executor_id', sa.String(length=250), nullable=True))
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("external_executor_id", sa.String(length=250), nullable=True))
def downgrade():
"""Unapply Add external executor ID to TI"""
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
- batch_op.drop_column('external_executor_id')
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.drop_column("external_executor_id")
diff --git a/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py b/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py
index d2449aa48dc1f..2d90f587c259d 100644
--- a/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py
+++ b/airflow/migrations/versions/0068_2_0_0_drop_kuberesourceversion_and_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Drop ``KubeResourceVersion`` and ``KubeWorkerId``
Revision ID: bef4f3d11e8b
@@ -23,18 +22,18 @@
Create Date: 2020-09-22 18:45:28.011654
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
-
-from airflow.compat.sqlalchemy import inspect
+from sqlalchemy import inspect
# revision identifiers, used by Alembic.
-revision = 'bef4f3d11e8b'
-down_revision = 'e1a11ece99cc'
+revision = "bef4f3d11e8b"
+down_revision = "e1a11ece99cc"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
WORKER_UUID_TABLE = "kube_worker_uuid"
diff --git a/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py b/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py
index 66f25cf29fac8..43c79291729b3 100644
--- a/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py
+++ b/airflow/migrations/versions/0069_2_0_0_add_scheduling_decision_to_dagrun_and_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``scheduling_decision`` to ``DagRun`` and ``DAG``
Revision ID: 98271e7606e2
@@ -23,6 +22,7 @@
Create Date: 2020-10-01 12:13:32.968148
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,11 +30,11 @@
from airflow.migrations.db_types import TIMESTAMP
# revision identifiers, used by Alembic.
-revision = '98271e7606e2'
-down_revision = 'bef4f3d11e8b'
+revision = "98271e7606e2"
+down_revision = "bef4f3d11e8b"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
@@ -46,24 +46,24 @@ def upgrade():
if is_sqlite:
op.execute("PRAGMA foreign_keys=off")
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
- batch_op.add_column(sa.Column('last_scheduling_decision', TIMESTAMP, nullable=True))
- batch_op.create_index('idx_last_scheduling_decision', ['last_scheduling_decision'], unique=False)
- batch_op.add_column(sa.Column('dag_hash', sa.String(32), nullable=True))
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("last_scheduling_decision", TIMESTAMP, nullable=True))
+ batch_op.create_index("idx_last_scheduling_decision", ["last_scheduling_decision"], unique=False)
+ batch_op.add_column(sa.Column("dag_hash", sa.String(32), nullable=True))
- with op.batch_alter_table('dag', schema=None) as batch_op:
- batch_op.add_column(sa.Column('next_dagrun', TIMESTAMP, nullable=True))
- batch_op.add_column(sa.Column('next_dagrun_create_after', TIMESTAMP, nullable=True))
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("next_dagrun", TIMESTAMP, nullable=True))
+ batch_op.add_column(sa.Column("next_dagrun_create_after", TIMESTAMP, nullable=True))
# Create with nullable and no default, then ALTER to set values, to avoid table level lock
- batch_op.add_column(sa.Column('concurrency', sa.Integer(), nullable=True))
- batch_op.add_column(sa.Column('has_task_concurrency_limits', sa.Boolean(), nullable=True))
+ batch_op.add_column(sa.Column("concurrency", sa.Integer(), nullable=True))
+ batch_op.add_column(sa.Column("has_task_concurrency_limits", sa.Boolean(), nullable=True))
- batch_op.create_index('idx_next_dagrun_create_after', ['next_dagrun_create_after'], unique=False)
+ batch_op.create_index("idx_next_dagrun_create_after", ["next_dagrun_create_after"], unique=False)
try:
from airflow.configuration import conf
- concurrency = conf.getint('core', 'dag_concurrency', fallback=16)
+ concurrency = conf.getint("core", "dag_concurrency", fallback=16)
except: # noqa
concurrency = 16
@@ -79,9 +79,9 @@ def upgrade():
"""
)
- with op.batch_alter_table('dag', schema=None) as batch_op:
- batch_op.alter_column('concurrency', type_=sa.Integer(), nullable=False)
- batch_op.alter_column('has_task_concurrency_limits', type_=sa.Boolean(), nullable=False)
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.alter_column("concurrency", type_=sa.Integer(), nullable=False)
+ batch_op.alter_column("has_task_concurrency_limits", type_=sa.Boolean(), nullable=False)
if is_sqlite:
op.execute("PRAGMA foreign_keys=on")
@@ -95,17 +95,17 @@ def downgrade():
if is_sqlite:
op.execute("PRAGMA foreign_keys=off")
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
- batch_op.drop_index('idx_last_scheduling_decision')
- batch_op.drop_column('last_scheduling_decision')
- batch_op.drop_column('dag_hash')
-
- with op.batch_alter_table('dag', schema=None) as batch_op:
- batch_op.drop_index('idx_next_dagrun_create_after')
- batch_op.drop_column('next_dagrun_create_after')
- batch_op.drop_column('next_dagrun')
- batch_op.drop_column('concurrency')
- batch_op.drop_column('has_task_concurrency_limits')
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.drop_index("idx_last_scheduling_decision")
+ batch_op.drop_column("last_scheduling_decision")
+ batch_op.drop_column("dag_hash")
+
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.drop_index("idx_next_dagrun_create_after")
+ batch_op.drop_column("next_dagrun_create_after")
+ batch_op.drop_column("next_dagrun")
+ batch_op.drop_column("concurrency")
+ batch_op.drop_column("has_task_concurrency_limits")
if is_sqlite:
op.execute("PRAGMA foreign_keys=on")
diff --git a/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py b/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py
index dd14290729997..97bb6a4e83df2 100644
--- a/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py
+++ b/airflow/migrations/versions/0070_2_0_0_fix_mssql_exec_date_rendered_task_instance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""fix_mssql_exec_date_rendered_task_instance_fields_for_MSSQL
Revision ID: 52d53670a240
@@ -23,18 +22,20 @@
Create Date: 2020-10-13 15:13:24.911486
"""
+from __future__ import annotations
+
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mssql
# revision identifiers, used by Alembic.
-revision = '52d53670a240'
-down_revision = '98271e7606e2'
+revision = "52d53670a240"
+down_revision = "98271e7606e2"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
-TABLE_NAME = 'rendered_task_instance_fields'
+TABLE_NAME = "rendered_task_instance_fields"
def upgrade():
@@ -49,11 +50,11 @@ def upgrade():
op.create_table(
TABLE_NAME,
- sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.Column('task_id', sa.String(length=250), nullable=False),
- sa.Column('execution_date', mssql.DATETIME2, nullable=False),
- sa.Column('rendered_fields', json_type(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'),
+ sa.Column("dag_id", sa.String(length=250), nullable=False),
+ sa.Column("task_id", sa.String(length=250), nullable=False),
+ sa.Column("execution_date", mssql.DATETIME2, nullable=False),
+ sa.Column("rendered_fields", json_type(), nullable=False),
+ sa.PrimaryKeyConstraint("dag_id", "task_id", "execution_date"),
)
@@ -69,9 +70,9 @@ def downgrade():
op.create_table(
TABLE_NAME,
- sa.Column('dag_id', sa.String(length=250), nullable=False),
- sa.Column('task_id', sa.String(length=250), nullable=False),
- sa.Column('execution_date', sa.TIMESTAMP, nullable=False),
- sa.Column('rendered_fields', json_type(), nullable=False),
- sa.PrimaryKeyConstraint('dag_id', 'task_id', 'execution_date'),
+ sa.Column("dag_id", sa.String(length=250), nullable=False),
+ sa.Column("task_id", sa.String(length=250), nullable=False),
+ sa.Column("execution_date", sa.TIMESTAMP, nullable=False),
+ sa.Column("rendered_fields", json_type(), nullable=False),
+ sa.PrimaryKeyConstraint("dag_id", "task_id", "execution_date"),
)
diff --git a/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py b/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py
index c42308393fef2..f0dd1c4d969a3 100644
--- a/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py
+++ b/airflow/migrations/versions/0071_2_0_0_add_job_id_to_dagrun_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``creating_job_id`` to ``DagRun`` table
Revision ID: 364159666cbd
@@ -23,23 +22,24 @@
Create Date: 2020-10-10 09:08:07.332456
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '364159666cbd'
-down_revision = '52d53670a240'
+revision = "364159666cbd"
+down_revision = "52d53670a240"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
"""Apply Add ``creating_job_id`` to ``DagRun`` table"""
- op.add_column('dag_run', sa.Column('creating_job_id', sa.Integer))
+ op.add_column("dag_run", sa.Column("creating_job_id", sa.Integer))
def downgrade():
"""Unapply Add job_id to DagRun table"""
- op.drop_column('dag_run', 'creating_job_id')
+ op.drop_column("dag_run", "creating_job_id")
diff --git a/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py b/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py
index 675df14259230..242878d5a8b52 100644
--- a/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py
+++ b/airflow/migrations/versions/0072_2_0_0_add_k8s_yaml_to_rendered_templates.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""add-k8s-yaml-to-rendered-templates
Revision ID: 45ba3f1493b9
@@ -23,6 +22,7 @@
Create Date: 2020-10-23 23:01:52.471442
"""
+from __future__ import annotations
import sqlalchemy_jsonfield
from alembic import op
@@ -31,14 +31,14 @@
from airflow.settings import json
# revision identifiers, used by Alembic.
-revision = '45ba3f1493b9'
-down_revision = '364159666cbd'
+revision = "45ba3f1493b9"
+down_revision = "364159666cbd"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
__tablename__ = "rendered_task_instance_fields"
-k8s_pod_yaml = Column('k8s_pod_yaml', sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+k8s_pod_yaml = Column("k8s_pod_yaml", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
def upgrade():
@@ -50,4 +50,4 @@ def upgrade():
def downgrade():
"""Unapply add-k8s-yaml-to-rendered-templates"""
with op.batch_alter_table(__tablename__, schema=None) as batch_op:
- batch_op.drop_column('k8s_pod_yaml')
+ batch_op.drop_column("k8s_pod_yaml")
diff --git a/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py b/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py
index 2538a12a89494..660da4ac2c4f9 100644
--- a/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py
+++ b/airflow/migrations/versions/0073_2_0_0_prefix_dag_permissions.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Prefix DAG permissions.
Revision ID: 849da589634d
@@ -23,6 +22,7 @@
Create Date: 2020-10-01 17:25:10.006322
"""
+from __future__ import annotations
from flask_appbuilder import SQLA
@@ -31,23 +31,23 @@
from airflow.www.fab_security.sqla.models import Action, Permission, Resource
# revision identifiers, used by Alembic.
-revision = '849da589634d'
-down_revision = '45ba3f1493b9'
+revision = "849da589634d"
+down_revision = "45ba3f1493b9"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def prefix_individual_dag_permissions(session):
- dag_perms = ['can_dag_read', 'can_dag_edit']
+ dag_perms = ["can_dag_read", "can_dag_edit"]
prefix = "DAG:"
perms = (
session.query(Permission)
.join(Action)
.filter(Action.name.in_(dag_perms))
.join(Resource)
- .filter(Resource.name != 'all_dags')
- .filter(Resource.name.notlike(prefix + '%'))
+ .filter(Resource.name != "all_dags")
+ .filter(Resource.name.notlike(prefix + "%"))
.all()
)
resource_ids = {permission.resource.id for permission in perms}
@@ -57,14 +57,14 @@ def prefix_individual_dag_permissions(session):
def remove_prefix_in_individual_dag_permissions(session):
- dag_perms = ['can_read', 'can_edit']
+ dag_perms = ["can_read", "can_edit"]
prefix = "DAG:"
perms = (
session.query(Permission)
.join(Action)
.filter(Action.name.in_(dag_perms))
.join(Resource)
- .filter(Resource.name.like(prefix + '%'))
+ .filter(Resource.name.like(prefix + "%"))
.all()
)
for permission in perms:
@@ -86,12 +86,12 @@ def get_or_create_dag_resource(session):
def get_or_create_all_dag_resource(session):
- all_dag_resource = get_resource_query(session, 'all_dags').first()
+ all_dag_resource = get_resource_query(session, "all_dags").first()
if all_dag_resource:
return all_dag_resource
all_dag_resource = Resource()
- all_dag_resource.name = 'all_dags'
+ all_dag_resource.name = "all_dags"
session.add(all_dag_resource)
session.commit()
@@ -156,19 +156,19 @@ def migrate_to_new_dag_permissions(db):
prefix_individual_dag_permissions(db.session)
# Update existing permissions to use `can_read` instead of `can_dag_read`
- can_dag_read_action = get_action_query(db.session, 'can_dag_read').first()
+ can_dag_read_action = get_action_query(db.session, "can_dag_read").first()
old_can_dag_read_permissions = get_permission_with_action_query(db.session, can_dag_read_action)
- can_read_action = get_or_create_action(db.session, 'can_read')
+ can_read_action = get_or_create_action(db.session, "can_read")
update_permission_action(db.session, old_can_dag_read_permissions, can_read_action)
# Update existing permissions to use `can_edit` instead of `can_dag_edit`
- can_dag_edit_action = get_action_query(db.session, 'can_dag_edit').first()
+ can_dag_edit_action = get_action_query(db.session, "can_dag_edit").first()
old_can_dag_edit_permissions = get_permission_with_action_query(db.session, can_dag_edit_action)
- can_edit_action = get_or_create_action(db.session, 'can_edit')
+ can_edit_action = get_or_create_action(db.session, "can_edit")
update_permission_action(db.session, old_can_dag_edit_permissions, can_edit_action)
# Update existing permissions for `all_dags` resource to use `DAGs` resource.
- all_dags_resource = get_resource_query(db.session, 'all_dags').first()
+ all_dags_resource = get_resource_query(db.session, "all_dags").first()
if all_dags_resource:
old_all_dags_permission = get_permission_with_resource_query(db.session, all_dags_resource)
dag_resource = get_or_create_dag_resource(db.session)
@@ -193,15 +193,15 @@ def undo_migrate_to_new_dag_permissions(session):
remove_prefix_in_individual_dag_permissions(session)
# Update existing permissions to use `can_dag_read` instead of `can_read`
- can_read_action = get_action_query(session, 'can_read').first()
+ can_read_action = get_action_query(session, "can_read").first()
new_can_read_permissions = get_permission_with_action_query(session, can_read_action)
- can_dag_read_action = get_or_create_action(session, 'can_dag_read')
+ can_dag_read_action = get_or_create_action(session, "can_dag_read")
update_permission_action(session, new_can_read_permissions, can_dag_read_action)
# Update existing permissions to use `can_dag_edit` instead of `can_edit`
- can_edit_action = get_action_query(session, 'can_edit').first()
+ can_edit_action = get_action_query(session, "can_edit").first()
new_can_edit_permissions = get_permission_with_action_query(session, can_edit_action)
- can_dag_edit_action = get_or_create_action(session, 'can_dag_edit')
+ can_dag_edit_action = get_or_create_action(session, "can_dag_edit")
update_permission_action(session, new_can_edit_permissions, can_dag_edit_action)
# Update existing permissions for `DAGs` resource to use `all_dags` resource.
diff --git a/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py b/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py
index 6942be25cfd18..1748ca3d5f3aa 100644
--- a/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py
+++ b/airflow/migrations/versions/0074_2_0_0_resource_based_permissions.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Resource based permissions.
Revision ID: 2c6edca13270
@@ -23,17 +22,19 @@
Create Date: 2020-10-21 00:18:52.529438
"""
+from __future__ import annotations
+
import logging
from airflow.security import permissions
-from airflow.www.app import create_app
+from airflow.www.app import cached_app
# revision identifiers, used by Alembic.
-revision = '2c6edca13270'
-down_revision = '849da589634d'
+revision = "2c6edca13270"
+down_revision = "849da589634d"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
mapping = {
@@ -287,7 +288,7 @@
def remap_permissions():
"""Apply Map Airflow permissions."""
- appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder
+ appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
for old, new in mapping.items():
(old_resource_name, old_action_name) = old
old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name)
@@ -312,7 +313,7 @@ def remap_permissions():
def undo_remap_permissions():
"""Unapply Map Airflow permissions"""
- appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder
+ appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
for old, new in mapping.items():
(new_resource_name, new_action_name) = new[0]
new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name)
diff --git a/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py b/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py
index 4c3f5835dcbfd..db03fae459818 100644
--- a/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py
+++ b/airflow/migrations/versions/0075_2_0_0_add_description_field_to_connection.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add description field to ``connection`` table
Revision ID: 61ec73d9401f
@@ -23,33 +22,34 @@
Create Date: 2020-09-10 14:56:30.279248
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '61ec73d9401f'
-down_revision = '2c6edca13270'
+revision = "61ec73d9401f"
+down_revision = "2c6edca13270"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
"""Apply Add description field to ``connection`` table"""
conn = op.get_bind()
- with op.batch_alter_table('connection') as batch_op:
+ with op.batch_alter_table("connection") as batch_op:
if conn.dialect.name == "mysql":
# Handles case where on mysql with utf8mb4 this would exceed the size of row
# We have to set text type in this migration even if originally it was string
# This is permanently fixed in the follow-up migration 64a7d6477aae
- batch_op.add_column(sa.Column('description', sa.Text(length=5000), nullable=True))
+ batch_op.add_column(sa.Column("description", sa.Text(length=5000), nullable=True))
else:
- batch_op.add_column(sa.Column('description', sa.String(length=5000), nullable=True))
+ batch_op.add_column(sa.Column("description", sa.String(length=5000), nullable=True))
def downgrade():
"""Unapply Add description field to ``connection`` table"""
- with op.batch_alter_table('connection', schema=None) as batch_op:
- batch_op.drop_column('description')
+ with op.batch_alter_table("connection", schema=None) as batch_op:
+ batch_op.drop_column("description")
diff --git a/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py b/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py
index dba78c04c461a..a397378e4c6b8 100644
--- a/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py
+++ b/airflow/migrations/versions/0076_2_0_0_fix_description_field_in_connection_to_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Fix description field in ``connection`` to be ``text``
Revision ID: 64a7d6477aae
@@ -23,16 +22,17 @@
Create Date: 2020-11-25 08:56:11.866607
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '64a7d6477aae'
-down_revision = '61ec73d9401f'
+revision = "64a7d6477aae"
+down_revision = "61ec73d9401f"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
@@ -43,15 +43,15 @@ def upgrade():
return
if conn.dialect.name == "mysql":
op.alter_column(
- 'connection',
- 'description',
+ "connection",
+ "description",
existing_type=sa.String(length=5000),
type_=sa.Text(length=5000),
existing_nullable=True,
)
else:
# postgres does not allow size modifier for text type
- op.alter_column('connection', 'description', existing_type=sa.String(length=5000), type_=sa.Text())
+ op.alter_column("connection", "description", existing_type=sa.String(length=5000), type_=sa.Text())
def downgrade():
@@ -62,8 +62,8 @@ def downgrade():
return
if conn.dialect.name == "mysql":
op.alter_column(
- 'connection',
- 'description',
+ "connection",
+ "description",
existing_type=sa.Text(5000),
type_=sa.String(length=5000),
existing_nullable=True,
@@ -71,8 +71,8 @@ def downgrade():
else:
# postgres does not allow size modifier for text type
op.alter_column(
- 'connection',
- 'description',
+ "connection",
+ "description",
existing_type=sa.Text(),
type_=sa.String(length=5000),
existing_nullable=True,
diff --git a/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py b/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py
index 2b905c8c2916d..c6343b7fc14f1 100644
--- a/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py
+++ b/airflow/migrations/versions/0077_2_0_0_change_field_in_dagcode_to_mediumtext_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Change field in ``DagCode`` to ``MEDIUMTEXT`` for MySql
Revision ID: e959f08ac86c
@@ -23,22 +22,24 @@
Create Date: 2020-12-07 16:31:43.982353
"""
+from __future__ import annotations
+
from alembic import op
from sqlalchemy.dialects import mysql
# revision identifiers, used by Alembic.
-revision = 'e959f08ac86c'
-down_revision = '64a7d6477aae'
+revision = "e959f08ac86c"
+down_revision = "64a7d6477aae"
branch_labels = None
depends_on = None
-airflow_version = '2.0.0'
+airflow_version = "2.0.0"
def upgrade():
conn = op.get_bind()
if conn.dialect.name == "mysql":
op.alter_column(
- table_name='dag_code', column_name='source_code', type_=mysql.MEDIUMTEXT, nullable=False
+ table_name="dag_code", column_name="source_code", type_=mysql.MEDIUMTEXT, nullable=False
)
diff --git a/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py b/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py
index 51e8f20dbaf6c..b9bc66d01e094 100644
--- a/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py
+++ b/airflow/migrations/versions/0078_2_0_1_remove_can_read_permission_on_config_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Remove ``can_read`` permission on config resource for ``User`` and ``Viewer`` role
Revision ID: 82b7c48c147f
@@ -23,17 +22,19 @@
Create Date: 2021-02-04 12:45:58.138224
"""
+from __future__ import annotations
+
import logging
from airflow.security import permissions
-from airflow.www.app import create_app
+from airflow.www.app import cached_app
# revision identifiers, used by Alembic.
-revision = '82b7c48c147f'
-down_revision = 'e959f08ac86c'
+revision = "82b7c48c147f"
+down_revision = "e959f08ac86c"
branch_labels = None
depends_on = None
-airflow_version = '2.0.1'
+airflow_version = "2.0.1"
def upgrade():
@@ -41,7 +42,7 @@ def upgrade():
log = logging.getLogger()
handlers = log.handlers[:]
- appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder
+ appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]]
can_read_on_config_perm = appbuilder.sm.get_permission(
permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG
@@ -58,7 +59,7 @@ def upgrade():
def downgrade():
"""Add can_read action on config resource for User and Viewer role"""
- appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder
+ appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
roles_to_modify = [role for role in appbuilder.sm.get_all_roles() if role.name in ["User", "Viewer"]]
can_read_on_config_perm = appbuilder.sm.get_permission(
permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG
diff --git a/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py b/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py
index 77873664f6c20..39a4be4cdd9b0 100644
--- a/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py
+++ b/airflow/migrations/versions/0079_2_0_2_increase_size_of_connection_extra_field_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase size of ``connection.extra`` field to handle multiple RSA keys
Revision ID: 449b4072c2da
@@ -23,23 +22,24 @@
Create Date: 2020-03-16 19:02:55.337710
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '449b4072c2da'
-down_revision = '82b7c48c147f'
+revision = "449b4072c2da"
+down_revision = "82b7c48c147f"
branch_labels = None
depends_on = None
-airflow_version = '2.0.2'
+airflow_version = "2.0.2"
def upgrade():
"""Apply increase_length_for_connection_password"""
- with op.batch_alter_table('connection', schema=None) as batch_op:
+ with op.batch_alter_table("connection", schema=None) as batch_op:
batch_op.alter_column(
- 'extra',
+ "extra",
existing_type=sa.VARCHAR(length=5000),
type_=sa.TEXT(),
existing_nullable=True,
@@ -48,9 +48,9 @@ def upgrade():
def downgrade():
"""Unapply increase_length_for_connection_password"""
- with op.batch_alter_table('connection', schema=None) as batch_op:
+ with op.batch_alter_table("connection", schema=None) as batch_op:
batch_op.alter_column(
- 'extra',
+ "extra",
existing_type=sa.TEXT(),
type_=sa.VARCHAR(length=5000),
existing_nullable=True,
diff --git a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
index f5ae34c2977c9..16a2da2c71d84 100644
--- a/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
+++ b/airflow/migrations/versions/0080_2_0_2_change_default_pool_slots_to_1.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Change default ``pool_slots`` to ``1``
Revision ID: 8646922c8a04
@@ -23,25 +22,38 @@
Create Date: 2021-02-23 23:19:22.409973
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '8646922c8a04'
-down_revision = '449b4072c2da'
+revision = "8646922c8a04"
+down_revision = "449b4072c2da"
branch_labels = None
depends_on = None
-airflow_version = '2.0.2'
+airflow_version = "2.0.2"
def upgrade():
"""Change default ``pool_slots`` to ``1`` and make pool_slots not nullable"""
+ op.execute("UPDATE task_instance SET pool_slots = 1 WHERE pool_slots IS NULL")
with op.batch_alter_table("task_instance", schema=None) as batch_op:
- batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=False, server_default='1')
+ batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=False, server_default="1")
def downgrade():
    """Unapply Change default ``pool_slots`` to ``1``"""
-    with op.batch_alter_table("task_instance", schema=None) as batch_op:
-        batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None)
+    conn = op.get_bind()
+    if conn.dialect.name == "mssql":
+        inspector = sa.inspect(conn.engine)
+        columns = inspector.get_columns("task_instance")
+        for col in columns:
+            if col["name"] == "pool_slots" and col["default"] == "('1')":
+                with op.batch_alter_table("task_instance", schema=None) as batch_op:
+                    batch_op.alter_column(
+                        "pool_slots", existing_type=sa.Integer, nullable=True, server_default=None
+                    )
+    else:
+        with op.batch_alter_table("task_instance", schema=None) as batch_op:
+            batch_op.alter_column("pool_slots", existing_type=sa.Integer, nullable=True, server_default=None)
diff --git a/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py b/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py
index 487abdb11994b..78a498bf69c09 100644
--- a/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py
+++ b/airflow/migrations/versions/0081_2_0_2_rename_last_scheduler_run_column.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Rename ``last_scheduler_run`` column in ``DAG`` table to ``last_parsed_time``
Revision ID: 2e42bb497a22
@@ -23,31 +22,32 @@
Create Date: 2021-03-04 19:50:38.880942
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mssql
# revision identifiers, used by Alembic.
-revision = '2e42bb497a22'
-down_revision = '8646922c8a04'
+revision = "2e42bb497a22"
+down_revision = "8646922c8a04"
branch_labels = None
depends_on = None
-airflow_version = '2.0.2'
+airflow_version = "2.0.2"
def upgrade():
"""Apply Rename ``last_scheduler_run`` column in ``DAG`` table to ``last_parsed_time``"""
conn = op.get_bind()
if conn.dialect.name == "mssql":
- with op.batch_alter_table('dag') as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
batch_op.alter_column(
- 'last_scheduler_run', new_column_name='last_parsed_time', type_=mssql.DATETIME2(precision=6)
+ "last_scheduler_run", new_column_name="last_parsed_time", type_=mssql.DATETIME2(precision=6)
)
else:
- with op.batch_alter_table('dag') as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
batch_op.alter_column(
- 'last_scheduler_run', new_column_name='last_parsed_time', type_=sa.TIMESTAMP(timezone=True)
+ "last_scheduler_run", new_column_name="last_parsed_time", type_=sa.TIMESTAMP(timezone=True)
)
@@ -55,12 +55,12 @@ def downgrade():
"""Unapply Rename ``last_scheduler_run`` column in ``DAG`` table to ``last_parsed_time``"""
conn = op.get_bind()
if conn.dialect.name == "mssql":
- with op.batch_alter_table('dag') as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
batch_op.alter_column(
- 'last_parsed_time', new_column_name='last_scheduler_run', type_=mssql.DATETIME2(precision=6)
+ "last_parsed_time", new_column_name="last_scheduler_run", type_=mssql.DATETIME2(precision=6)
)
else:
- with op.batch_alter_table('dag') as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
batch_op.alter_column(
- 'last_parsed_time', new_column_name='last_scheduler_run', type_=sa.TIMESTAMP(timezone=True)
+ "last_parsed_time", new_column_name="last_scheduler_run", type_=sa.TIMESTAMP(timezone=True)
)
diff --git a/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py b/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py
index 30e8ca553f8c3..5b3c4b89b4bf1 100644
--- a/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py
+++ b/airflow/migrations/versions/0082_2_1_0_increase_pool_name_size_in_taskinstance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase maximum length of pool name in ``task_instance`` table to ``256`` characters
Revision ID: 90d1635d7b86
@@ -23,32 +22,33 @@
Create Date: 2021-04-05 09:37:54.848731
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '90d1635d7b86'
-down_revision = '2e42bb497a22'
+revision = "90d1635d7b86"
+down_revision = "2e42bb497a22"
branch_labels = None
depends_on = None
-airflow_version = '2.1.0'
+airflow_version = "2.1.0"
def upgrade():
"""Apply Increase maximum length of pool name in ``task_instance`` table to ``256`` characters"""
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.alter_column('pool', type_=sa.String(256), nullable=False)
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.alter_column("pool", type_=sa.String(256), nullable=False)
def downgrade():
"""Unapply Increase maximum length of pool name in ``task_instance`` table to ``256`` characters"""
conn = op.get_bind()
- if conn.dialect.name == 'mssql':
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.drop_index('ti_pool')
- batch_op.alter_column('pool', type_=sa.String(50), nullable=False)
- batch_op.create_index('ti_pool', ['pool'])
+ if conn.dialect.name == "mssql":
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.drop_index("ti_pool")
+ batch_op.alter_column("pool", type_=sa.String(50), nullable=False)
+ batch_op.create_index("ti_pool", ["pool"])
else:
- with op.batch_alter_table('task_instance') as batch_op:
- batch_op.alter_column('pool', type_=sa.String(50), nullable=False)
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.alter_column("pool", type_=sa.String(50), nullable=False)
diff --git a/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py b/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py
index 761a718f32237..14fb651893be1 100644
--- a/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py
+++ b/airflow/migrations/versions/0083_2_1_0_add_description_field_to_variable.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add description field to ``Variable`` model
Revision ID: e165e7455d70
@@ -23,25 +22,26 @@
Create Date: 2021-04-11 22:28:02.107290
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'e165e7455d70'
-down_revision = '90d1635d7b86'
+revision = "e165e7455d70"
+down_revision = "90d1635d7b86"
branch_labels = None
depends_on = None
-airflow_version = '2.1.0'
+airflow_version = "2.1.0"
def upgrade():
"""Apply Add description field to ``Variable`` model"""
- with op.batch_alter_table('variable', schema=None) as batch_op:
- batch_op.add_column(sa.Column('description', sa.Text(), nullable=True))
+ with op.batch_alter_table("variable", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("description", sa.Text(), nullable=True))
def downgrade():
"""Unapply Add description field to ``Variable`` model"""
- with op.batch_alter_table('variable', schema=None) as batch_op:
- batch_op.drop_column('description')
+ with op.batch_alter_table("variable", schema=None) as batch_op:
+ batch_op.drop_column("description")
diff --git a/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py b/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py
index cd162de566e88..f5e8706c09d54 100644
--- a/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py
+++ b/airflow/migrations/versions/0084_2_1_0_resource_based_permissions_for_default_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Resource based permissions for default ``Flask-AppBuilder`` views
Revision ID: a13f7613ad25
@@ -23,17 +22,19 @@
Create Date: 2021-03-20 21:23:05.793378
"""
+from __future__ import annotations
+
import logging
from airflow.security import permissions
-from airflow.www.app import create_app
+from airflow.www.app import cached_app
# revision identifiers, used by Alembic.
-revision = 'a13f7613ad25'
-down_revision = 'e165e7455d70'
+revision = "a13f7613ad25"
+down_revision = "e165e7455d70"
branch_labels = None
depends_on = None
-airflow_version = '2.1.0'
+airflow_version = "2.1.0"
mapping = {
@@ -139,7 +140,7 @@
def remap_permissions():
"""Apply Map Airflow permissions."""
- appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder
+ appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
for old, new in mapping.items():
(old_resource_name, old_action_name) = old
old_permission = appbuilder.sm.get_permission(old_action_name, old_resource_name)
@@ -164,7 +165,7 @@ def remap_permissions():
def undo_remap_permissions():
"""Unapply Map Airflow permissions"""
- appbuilder = create_app(config={'FAB_UPDATE_PERMS': False}).appbuilder
+ appbuilder = cached_app(config={"FAB_UPDATE_PERMS": False}).appbuilder
for old, new in mapping.items():
(new_resource_name, new_action_name) = new[0]
new_permission = appbuilder.sm.get_permission(new_action_name, new_resource_name)
diff --git a/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py b/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py
index f40daa1035b23..3d27a5d1e30f4 100644
--- a/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py
+++ b/airflow/migrations/versions/0085_2_1_3_add_queued_at_column_to_dagrun_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``queued_at`` column in ``dag_run`` table
Revision ID: 97cdd93827b8
@@ -23,6 +22,7 @@
Create Date: 2021-06-29 21:53:48.059438
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,19 +30,19 @@
from airflow.migrations.db_types import TIMESTAMP
# revision identifiers, used by Alembic.
-revision = '97cdd93827b8'
-down_revision = 'a13f7613ad25'
+revision = "97cdd93827b8"
+down_revision = "a13f7613ad25"
branch_labels = None
depends_on = None
-airflow_version = '2.1.3'
+airflow_version = "2.1.3"
def upgrade():
"""Apply Add ``queued_at`` column in ``dag_run`` table"""
- op.add_column('dag_run', sa.Column('queued_at', TIMESTAMP, nullable=True))
+ op.add_column("dag_run", sa.Column("queued_at", TIMESTAMP, nullable=True))
def downgrade():
"""Unapply Add ``queued_at`` column in ``dag_run`` table"""
- with op.batch_alter_table('dag_run') as batch_op:
- batch_op.drop_column('queued_at')
+ with op.batch_alter_table("dag_run") as batch_op:
+ batch_op.drop_column("queued_at")
diff --git a/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py b/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py
index 49ba327dc00e8..c68e8e8e07c97 100644
--- a/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py
+++ b/airflow/migrations/versions/0086_2_1_4_add_max_active_runs_column_to_dagmodel_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``max_active_runs`` column to ``dag_model`` table
Revision ID: 092435bf5d12
@@ -23,27 +22,28 @@
Create Date: 2021-09-06 21:29:24.728923
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
-revision = '092435bf5d12'
-down_revision = '97cdd93827b8'
+revision = "092435bf5d12"
+down_revision = "97cdd93827b8"
branch_labels = None
depends_on = None
-airflow_version = '2.1.4'
+airflow_version = "2.1.4"
def upgrade():
"""Apply Add ``max_active_runs`` column to ``dag_model`` table"""
- op.add_column('dag', sa.Column('max_active_runs', sa.Integer(), nullable=True))
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ op.add_column("dag", sa.Column("max_active_runs", sa.Integer(), nullable=True))
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
# Add index to dag_run.dag_id and also add index to dag_run.state where state==running
- batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+ batch_op.create_index("idx_dag_run_dag_id", ["dag_id"])
batch_op.create_index(
- 'idx_dag_run_running_dags',
+ "idx_dag_run_running_dags",
["state", "dag_id"],
postgresql_where=text("state='running'"),
mssql_where=text("state='running'"),
@@ -53,9 +53,9 @@ def upgrade():
def downgrade():
"""Unapply Add ``max_active_runs`` column to ``dag_model`` table"""
- with op.batch_alter_table('dag') as batch_op:
- batch_op.drop_column('max_active_runs')
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
+ batch_op.drop_column("max_active_runs")
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
# Drop index to dag_run.dag_id and also drop index to dag_run.state where state==running
- batch_op.drop_index('idx_dag_run_dag_id')
- batch_op.drop_index('idx_dag_run_running_dags')
+ batch_op.drop_index("idx_dag_run_dag_id")
+ batch_op.drop_index("idx_dag_run_running_dags")
diff --git a/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py b/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py
index f9e55fce83550..d7c34cc5b0586 100644
--- a/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py
+++ b/airflow/migrations/versions/0087_2_1_4_add_index_on_state_dag_id_for_queued_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add index on state, dag_id for queued ``dagrun``
Revision ID: ccde3e26fe78
@@ -23,23 +22,24 @@
Create Date: 2021-09-08 16:35:34.867711
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
-revision = 'ccde3e26fe78'
-down_revision = '092435bf5d12'
+revision = "ccde3e26fe78"
+down_revision = "092435bf5d12"
branch_labels = None
depends_on = None
-airflow_version = '2.1.4'
+airflow_version = "2.1.4"
def upgrade():
"""Apply Add index on state, dag_id for queued ``dagrun``"""
- with op.batch_alter_table('dag_run') as batch_op:
+ with op.batch_alter_table("dag_run") as batch_op:
batch_op.create_index(
- 'idx_dag_run_queued_dags',
+ "idx_dag_run_queued_dags",
["state", "dag_id"],
postgresql_where=text("state='queued'"),
mssql_where=text("state='queued'"),
@@ -49,5 +49,5 @@ def upgrade():
def downgrade():
"""Unapply Add index on state, dag_id for queued ``dagrun``"""
- with op.batch_alter_table('dag_run') as batch_op:
- batch_op.drop_index('idx_dag_run_queued_dags')
+ with op.batch_alter_table("dag_run") as batch_op:
+ batch_op.drop_index("idx_dag_run_queued_dags")
diff --git a/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py b/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py
index 9fdb56ac9bf16..a04a9ffe70b32 100644
--- a/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py
+++ b/airflow/migrations/versions/0088_2_2_0_improve_mssql_compatibility.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Improve MSSQL compatibility
Revision ID: 83f031fd9f1c
@@ -23,6 +22,7 @@
Create Date: 2021-04-06 12:22:02.197726
"""
+from __future__ import annotations
from collections import defaultdict
@@ -33,11 +33,11 @@
from airflow.migrations.db_types import TIMESTAMP
# revision identifiers, used by Alembic.
-revision = '83f031fd9f1c'
-down_revision = 'ccde3e26fe78'
+revision = "83f031fd9f1c"
+down_revision = "ccde3e26fe78"
branch_labels = None
depends_on = None
-airflow_version = '2.2.0'
+airflow_version = "2.2.0"
def is_table_empty(conn, table_name):
@@ -48,10 +48,10 @@ def is_table_empty(conn, table_name):
:param table_name: table name
:return: Boolean indicating if the table is empty
"""
- return conn.execute(f'select TOP 1 * from {table_name}').first() is None
+ return conn.execute(f"select TOP 1 * from {table_name}").first() is None
-def get_table_constraints(conn, table_name):
+def get_table_constraints(conn, table_name) -> dict[tuple[str, str], list[str]]:
"""
This function returns primary and unique constraints
along with column names. Some tables like task_instance
@@ -62,7 +62,6 @@ def get_table_constraints(conn, table_name):
:param conn: sql connection object
:param table_name: table name
:return: a dictionary of ((constraint name, constraint type), column name) of table
- :rtype: defaultdict(list)
"""
query = f"""SELECT tc.CONSTRAINT_NAME , tc.CONSTRAINT_TYPE, ccu.COLUMN_NAME
FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS AS tc
@@ -87,9 +86,9 @@ def drop_column_constraints(operator, column_name, constraint_dict):
for constraint, columns in constraint_dict.items():
if column_name in columns:
if constraint[1].lower().startswith("primary"):
- operator.drop_constraint(constraint[0], type_='primary')
+ operator.drop_constraint(constraint[0], type_="primary")
elif constraint[1].lower().startswith("unique"):
- operator.drop_constraint(constraint[0], type_='unique')
+ operator.drop_constraint(constraint[0], type_="unique")
def create_constraints(operator, column_name, constraint_dict):
@@ -146,32 +145,32 @@ def alter_mssql_datetime_column(conn, op, table_name, column_name, nullable):
def upgrade():
"""Improve compatibility with MSSQL backend"""
conn = op.get_bind()
- if conn.dialect.name != 'mssql':
+ if conn.dialect.name != "mssql":
return
- recreate_mssql_ts_column(conn, op, 'dag_code', 'last_updated')
- recreate_mssql_ts_column(conn, op, 'rendered_task_instance_fields', 'execution_date')
- alter_mssql_datetime_column(conn, op, 'serialized_dag', 'last_updated', False)
+ recreate_mssql_ts_column(conn, op, "dag_code", "last_updated")
+ recreate_mssql_ts_column(conn, op, "rendered_task_instance_fields", "execution_date")
+ alter_mssql_datetime_column(conn, op, "serialized_dag", "last_updated", False)
op.alter_column(table_name="xcom", column_name="timestamp", type_=TIMESTAMP, nullable=False)
- with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
- task_reschedule_batch_op.alter_column(column_name='end_date', type_=TIMESTAMP, nullable=False)
- task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=TIMESTAMP, nullable=False)
- task_reschedule_batch_op.alter_column(column_name='start_date', type_=TIMESTAMP, nullable=False)
- with op.batch_alter_table('task_fail') as task_fail_batch_op:
- task_fail_batch_op.drop_index('idx_task_fail_dag_task_date')
+ with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op:
+ task_reschedule_batch_op.alter_column(column_name="end_date", type_=TIMESTAMP, nullable=False)
+ task_reschedule_batch_op.alter_column(column_name="reschedule_date", type_=TIMESTAMP, nullable=False)
+ task_reschedule_batch_op.alter_column(column_name="start_date", type_=TIMESTAMP, nullable=False)
+ with op.batch_alter_table("task_fail") as task_fail_batch_op:
+ task_fail_batch_op.drop_index("idx_task_fail_dag_task_date")
task_fail_batch_op.alter_column(column_name="execution_date", type_=TIMESTAMP, nullable=False)
task_fail_batch_op.create_index(
- 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_fail_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False
)
- with op.batch_alter_table('task_instance') as task_instance_batch_op:
- task_instance_batch_op.drop_index('ti_state_lkp')
+ with op.batch_alter_table("task_instance") as task_instance_batch_op:
+ task_instance_batch_op.drop_index("ti_state_lkp")
task_instance_batch_op.create_index(
- 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date', 'state'], unique=False
+ "ti_state_lkp", ["dag_id", "task_id", "execution_date", "state"], unique=False
)
- constraint_dict = get_table_constraints(conn, 'dag_run')
+ constraint_dict = get_table_constraints(conn, "dag_run")
for constraint, columns in constraint_dict.items():
- if 'dag_id' in columns:
+ if "dag_id" in columns:
if constraint[1].lower().startswith("unique"):
- op.drop_constraint(constraint[0], 'dag_run', type_='unique')
+ op.drop_constraint(constraint[0], "dag_run", type_="unique")
# create filtered indexes
conn.execute(
"""CREATE UNIQUE NONCLUSTERED INDEX idx_not_null_dag_id_execution_date
@@ -188,25 +187,25 @@ def upgrade():
def downgrade():
"""Reverse MSSQL backend compatibility improvements"""
conn = op.get_bind()
- if conn.dialect.name != 'mssql':
+ if conn.dialect.name != "mssql":
return
op.alter_column(table_name="xcom", column_name="timestamp", type_=TIMESTAMP, nullable=True)
- with op.batch_alter_table('task_reschedule') as task_reschedule_batch_op:
- task_reschedule_batch_op.alter_column(column_name='end_date', type_=TIMESTAMP, nullable=True)
- task_reschedule_batch_op.alter_column(column_name='reschedule_date', type_=TIMESTAMP, nullable=True)
- task_reschedule_batch_op.alter_column(column_name='start_date', type_=TIMESTAMP, nullable=True)
- with op.batch_alter_table('task_fail') as task_fail_batch_op:
- task_fail_batch_op.drop_index('idx_task_fail_dag_task_date')
+ with op.batch_alter_table("task_reschedule") as task_reschedule_batch_op:
+ task_reschedule_batch_op.alter_column(column_name="end_date", type_=TIMESTAMP, nullable=True)
+ task_reschedule_batch_op.alter_column(column_name="reschedule_date", type_=TIMESTAMP, nullable=True)
+ task_reschedule_batch_op.alter_column(column_name="start_date", type_=TIMESTAMP, nullable=True)
+ with op.batch_alter_table("task_fail") as task_fail_batch_op:
+ task_fail_batch_op.drop_index("idx_task_fail_dag_task_date")
task_fail_batch_op.alter_column(column_name="execution_date", type_=TIMESTAMP, nullable=False)
task_fail_batch_op.create_index(
- 'idx_task_fail_dag_task_date', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "idx_task_fail_dag_task_date", ["dag_id", "task_id", "execution_date"], unique=False
)
- with op.batch_alter_table('task_instance') as task_instance_batch_op:
- task_instance_batch_op.drop_index('ti_state_lkp')
+ with op.batch_alter_table("task_instance") as task_instance_batch_op:
+ task_instance_batch_op.drop_index("ti_state_lkp")
task_instance_batch_op.create_index(
- 'ti_state_lkp', ['dag_id', 'task_id', 'execution_date'], unique=False
+ "ti_state_lkp", ["dag_id", "task_id", "execution_date"], unique=False
)
- op.create_unique_constraint('UQ__dag_run__dag_id_run_id', 'dag_run', ['dag_id', 'run_id'])
- op.create_unique_constraint('UQ__dag_run__dag_id_execution_date', 'dag_run', ['dag_id', 'execution_date'])
- op.drop_index('idx_not_null_dag_id_execution_date', table_name='dag_run')
- op.drop_index('idx_not_null_dag_id_run_id', table_name='dag_run')
+ op.create_unique_constraint("UQ__dag_run__dag_id_run_id", "dag_run", ["dag_id", "run_id"])
+ op.create_unique_constraint("UQ__dag_run__dag_id_execution_date", "dag_run", ["dag_id", "execution_date"])
+ op.drop_index("idx_not_null_dag_id_execution_date", table_name="dag_run")
+ op.drop_index("idx_not_null_dag_id_run_id", table_name="dag_run")
diff --git a/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py b/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py
index 45a4559056418..68ef900a2db98 100644
--- a/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py
+++ b/airflow/migrations/versions/0089_2_2_0_make_xcom_pkey_columns_non_nullable.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Make XCom primary key columns non-nullable
Revision ID: e9304a3141f0
@@ -23,37 +22,39 @@
Create Date: 2021-04-06 13:22:02.197726
"""
+from __future__ import annotations
+
from alembic import op
from airflow.migrations.db_types import TIMESTAMP, StringID
# revision identifiers, used by Alembic.
-revision = 'e9304a3141f0'
-down_revision = '83f031fd9f1c'
+revision = "e9304a3141f0"
+down_revision = "83f031fd9f1c"
branch_labels = None
depends_on = None
-airflow_version = '2.2.0'
+airflow_version = "2.2.0"
def upgrade():
"""Apply Make XCom primary key columns non-nullable"""
conn = op.get_bind()
- with op.batch_alter_table('xcom') as bop:
+ with op.batch_alter_table("xcom") as bop:
bop.alter_column("key", type_=StringID(length=512), nullable=False)
bop.alter_column("execution_date", type_=TIMESTAMP, nullable=False)
- if conn.dialect.name == 'mssql':
- bop.create_primary_key('pk_xcom', ['dag_id', 'task_id', 'key', 'execution_date'])
+ if conn.dialect.name == "mssql":
+ bop.create_primary_key("pk_xcom", ["dag_id", "task_id", "key", "execution_date"])
def downgrade():
"""Unapply Make XCom primary key columns non-nullable"""
conn = op.get_bind()
- with op.batch_alter_table('xcom') as bop:
+ with op.batch_alter_table("xcom") as bop:
# regardless of what the model defined, the `key` and `execution_date`
# columns were always non-nullable for mysql, sqlite and postgres, so leave them alone
- if conn.dialect.name == 'mssql':
- bop.drop_constraint('pk_xcom', 'primary')
+ if conn.dialect.name == "mssql":
+ bop.drop_constraint("pk_xcom", "primary")
# execution_date and key weren't nullable in the other databases
bop.alter_column("key", type_=StringID(length=512), nullable=True)
bop.alter_column("execution_date", type_=TIMESTAMP, nullable=True)
diff --git a/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py b/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py
index f71fdf38821c8..d1816e32ed42b 100644
--- a/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py
+++ b/airflow/migrations/versions/0090_2_2_0_rename_concurrency_column_in_dag_table_.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Rename ``concurrency`` column in ``dag`` table to ``max_active_tasks``
Revision ID: 30867afad44a
@@ -23,16 +22,17 @@
Create Date: 2021-06-04 22:11:19.849981
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '30867afad44a'
-down_revision = 'e9304a3141f0'
+revision = "30867afad44a"
+down_revision = "e9304a3141f0"
branch_labels = None
depends_on = None
-airflow_version = '2.2.0'
+airflow_version = "2.2.0"
def upgrade():
@@ -42,10 +42,10 @@ def upgrade():
if is_sqlite:
op.execute("PRAGMA foreign_keys=off")
- with op.batch_alter_table('dag') as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
batch_op.alter_column(
- 'concurrency',
- new_column_name='max_active_tasks',
+ "concurrency",
+ new_column_name="max_active_tasks",
type_=sa.Integer(),
nullable=False,
)
@@ -55,10 +55,10 @@ def upgrade():
def downgrade():
"""Unapply Rename ``concurrency`` column in ``dag`` table to ``max_active_tasks``"""
- with op.batch_alter_table('dag') as batch_op:
+ with op.batch_alter_table("dag") as batch_op:
batch_op.alter_column(
- 'max_active_tasks',
- new_column_name='concurrency',
+ "max_active_tasks",
+ new_column_name="concurrency",
type_=sa.Integer(),
nullable=False,
)
diff --git a/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py b/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py
index 0f83d4ff8dd59..34d32d9d8c914 100644
--- a/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py
+++ b/airflow/migrations/versions/0091_2_2_0_add_trigger_table_and_task_info.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Adds ``trigger`` table and deferrable operator columns to task instance
Revision ID: 54bebd308c5f
@@ -23,6 +22,7 @@
Create Date: 2021-04-14 12:56:40.688260
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,41 +30,41 @@
from airflow.utils.sqlalchemy import ExtendedJSON
# revision identifiers, used by Alembic.
-revision = '54bebd308c5f'
-down_revision = '30867afad44a'
+revision = "54bebd308c5f"
+down_revision = "30867afad44a"
branch_labels = None
depends_on = None
-airflow_version = '2.2.0'
+airflow_version = "2.2.0"
def upgrade():
"""Apply Adds ``trigger`` table and deferrable operator columns to task instance"""
op.create_table(
- 'trigger',
- sa.Column('id', sa.Integer(), primary_key=True, nullable=False),
- sa.Column('classpath', sa.String(length=1000), nullable=False),
- sa.Column('kwargs', ExtendedJSON(), nullable=False),
- sa.Column('created_date', sa.DateTime(), nullable=False),
- sa.Column('triggerer_id', sa.Integer(), nullable=True),
+ "trigger",
+ sa.Column("id", sa.Integer(), primary_key=True, nullable=False),
+ sa.Column("classpath", sa.String(length=1000), nullable=False),
+ sa.Column("kwargs", ExtendedJSON(), nullable=False),
+ sa.Column("created_date", sa.DateTime(), nullable=False),
+ sa.Column("triggerer_id", sa.Integer(), nullable=True),
)
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
- batch_op.add_column(sa.Column('trigger_id', sa.Integer()))
- batch_op.add_column(sa.Column('trigger_timeout', sa.DateTime()))
- batch_op.add_column(sa.Column('next_method', sa.String(length=1000)))
- batch_op.add_column(sa.Column('next_kwargs', ExtendedJSON()))
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("trigger_id", sa.Integer()))
+ batch_op.add_column(sa.Column("trigger_timeout", sa.DateTime()))
+ batch_op.add_column(sa.Column("next_method", sa.String(length=1000)))
+ batch_op.add_column(sa.Column("next_kwargs", ExtendedJSON()))
batch_op.create_foreign_key(
- 'task_instance_trigger_id_fkey', 'trigger', ['trigger_id'], ['id'], ondelete="CASCADE"
+ "task_instance_trigger_id_fkey", "trigger", ["trigger_id"], ["id"], ondelete="CASCADE"
)
- batch_op.create_index('ti_trigger_id', ['trigger_id'])
+ batch_op.create_index("ti_trigger_id", ["trigger_id"])
def downgrade():
"""Unapply Adds ``trigger`` table and deferrable operator columns to task instance"""
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
- batch_op.drop_constraint('task_instance_trigger_id_fkey', type_='foreignkey')
- batch_op.drop_index('ti_trigger_id')
- batch_op.drop_column('trigger_id')
- batch_op.drop_column('trigger_timeout')
- batch_op.drop_column('next_method')
- batch_op.drop_column('next_kwargs')
- op.drop_table('trigger')
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.drop_constraint("task_instance_trigger_id_fkey", type_="foreignkey")
+ batch_op.drop_index("ti_trigger_id")
+ batch_op.drop_column("trigger_id")
+ batch_op.drop_column("trigger_timeout")
+ batch_op.drop_column("next_method")
+ batch_op.drop_column("next_kwargs")
+ op.drop_table("trigger")
diff --git a/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py b/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py
index 1aaa2230ee571..8a4fac006b40c 100644
--- a/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py
+++ b/airflow/migrations/versions/0092_2_2_0_add_data_interval_start_end_to_dagmodel_and_dagrun.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add data_interval_[start|end] to DagModel and DagRun.
Revision ID: 142555e44c17
@@ -23,6 +22,7 @@
Create Date: 2021-06-09 08:28:02.089817
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy import Column
@@ -34,7 +34,7 @@
down_revision = "54bebd308c5f"
branch_labels = None
depends_on = None
-airflow_version = '2.2.0'
+airflow_version = "2.2.0"
def upgrade():
diff --git a/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py b/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py
index b7bff79edfa6b..57cc7ac70c549 100644
--- a/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py
+++ b/airflow/migrations/versions/0093_2_2_0_taskinstance_keyed_to_dagrun.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Change ``TaskInstance`` and ``TaskReschedule`` tables from execution_date to run_id.
Revision ID: 7b2661a43ba3
@@ -23,6 +22,7 @@
Create Date: 2021-07-15 15:26:12.710749
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -34,33 +34,33 @@
ID_LEN = 250
# revision identifiers, used by Alembic.
-revision = '7b2661a43ba3'
-down_revision = '142555e44c17'
+revision = "7b2661a43ba3"
+down_revision = "142555e44c17"
branch_labels = None
depends_on = None
-airflow_version = '2.2.0'
+airflow_version = "2.2.0"
# Just Enough Table to run the conditions for update.
task_instance = table(
- 'task_instance',
- column('task_id', sa.String),
- column('dag_id', sa.String),
- column('run_id', sa.String),
- column('execution_date', sa.TIMESTAMP),
+ "task_instance",
+ column("task_id", sa.String),
+ column("dag_id", sa.String),
+ column("run_id", sa.String),
+ column("execution_date", sa.TIMESTAMP),
)
task_reschedule = table(
- 'task_reschedule',
- column('task_id', sa.String),
- column('dag_id', sa.String),
- column('run_id', sa.String),
- column('execution_date', sa.TIMESTAMP),
+ "task_reschedule",
+ column("task_id", sa.String),
+ column("dag_id", sa.String),
+ column("run_id", sa.String),
+ column("execution_date", sa.TIMESTAMP),
)
dag_run = table(
- 'dag_run',
- column('dag_id', sa.String),
- column('run_id', sa.String),
- column('execution_date', sa.TIMESTAMP),
+ "dag_run",
+ column("dag_id", sa.String),
+ column("run_id", sa.String),
+ column("execution_date", sa.TIMESTAMP),
)
@@ -72,75 +72,77 @@ def upgrade():
dt_type = TIMESTAMP
string_id_col_type = StringID()
- if dialect_name == 'sqlite':
+ if dialect_name == "sqlite":
naming_convention = {
"uq": "%(table_name)s_%(column_0_N_name)s_key",
}
# The naming_convention force the previously un-named UNIQUE constraints to have the right name
with op.batch_alter_table(
- 'dag_run', naming_convention=naming_convention, recreate="always"
+ "dag_run", naming_convention=naming_convention, recreate="always"
) as batch_op:
- batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
- batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
- batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
- elif dialect_name == 'mysql':
- with op.batch_alter_table('dag_run') as batch_op:
+ batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=False)
+ batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=False)
+ batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False)
+ elif dialect_name == "mysql":
+ with op.batch_alter_table("dag_run") as batch_op:
batch_op.alter_column(
- 'dag_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+ "dag_id", existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
)
batch_op.alter_column(
- 'run_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+ "run_id", existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
)
- batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
- batch_op.drop_constraint('dag_id', 'unique')
- batch_op.drop_constraint('dag_id_2', 'unique')
+ batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False)
+ inspector = sa.inspect(conn.engine)
+ unique_keys = inspector.get_unique_constraints("dag_run")
+ for unique_key in unique_keys:
+ batch_op.drop_constraint(unique_key["name"], type_="unique")
batch_op.create_unique_constraint(
- 'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date']
+ "dag_run_dag_id_execution_date_key", ["dag_id", "execution_date"]
)
- batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
- elif dialect_name == 'mssql':
+ batch_op.create_unique_constraint("dag_run_dag_id_run_id_key", ["dag_id", "run_id"])
+ elif dialect_name == "mssql":
- with op.batch_alter_table('dag_run') as batch_op:
- batch_op.drop_index('idx_not_null_dag_id_execution_date')
- batch_op.drop_index('idx_not_null_dag_id_run_id')
+ with op.batch_alter_table("dag_run") as batch_op:
+ batch_op.drop_index("idx_not_null_dag_id_execution_date")
+ batch_op.drop_index("idx_not_null_dag_id_run_id")
- batch_op.drop_index('dag_id_state')
- batch_op.drop_index('idx_dag_run_dag_id')
- batch_op.drop_index('idx_dag_run_running_dags')
- batch_op.drop_index('idx_dag_run_queued_dags')
+ batch_op.drop_index("dag_id_state")
+ batch_op.drop_index("idx_dag_run_dag_id")
+ batch_op.drop_index("idx_dag_run_running_dags")
+ batch_op.drop_index("idx_dag_run_queued_dags")
- batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
- batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
- batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
+ batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=False)
+ batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False)
+ batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=False)
# _Somehow_ mssql was missing these constraints entirely
batch_op.create_unique_constraint(
- 'dag_run_dag_id_execution_date_key', ['dag_id', 'execution_date']
+ "dag_run_dag_id_execution_date_key", ["dag_id", "execution_date"]
)
- batch_op.create_unique_constraint('dag_run_dag_id_run_id_key', ['dag_id', 'run_id'])
+ batch_op.create_unique_constraint("dag_run_dag_id_run_id_key", ["dag_id", "run_id"])
- batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
- batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+ batch_op.create_index("dag_id_state", ["dag_id", "state"], unique=False)
+ batch_op.create_index("idx_dag_run_dag_id", ["dag_id"])
batch_op.create_index(
- 'idx_dag_run_running_dags',
+ "idx_dag_run_running_dags",
["state", "dag_id"],
mssql_where=sa.text("state='running'"),
)
batch_op.create_index(
- 'idx_dag_run_queued_dags',
+ "idx_dag_run_queued_dags",
["state", "dag_id"],
mssql_where=sa.text("state='queued'"),
)
else:
# Make sure DagRun PK columns are non-nullable
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
- batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=False)
- batch_op.alter_column('execution_date', existing_type=dt_type, nullable=False)
- batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=False)
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=False)
+ batch_op.alter_column("execution_date", existing_type=dt_type, nullable=False)
+ batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=False)
# First create column nullable
- op.add_column('task_instance', sa.Column('run_id', type_=string_id_col_type, nullable=True))
- op.add_column('task_reschedule', sa.Column('run_id', type_=string_id_col_type, nullable=True))
+ op.add_column("task_instance", sa.Column("run_id", type_=string_id_col_type, nullable=True))
+ op.add_column("task_reschedule", sa.Column("run_id", type_=string_id_col_type, nullable=True))
#
# TaskReschedule has a FK to TaskInstance, so we have to update that before
@@ -149,23 +151,23 @@ def upgrade():
update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.run_id)
op.execute(update_query)
- with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
batch_op.alter_column(
- 'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+ "run_id", existing_type=string_id_col_type, existing_nullable=True, nullable=False
)
- batch_op.drop_constraint('task_reschedule_dag_task_date_fkey', 'foreignkey')
+ batch_op.drop_constraint("task_reschedule_dag_task_date_fkey", "foreignkey")
if dialect_name == "mysql":
# MySQL creates both an index and a constraint, so we have to drop both
- batch_op.drop_index('task_reschedule_dag_task_date_fkey')
+ batch_op.drop_index("task_reschedule_dag_task_date_fkey")
batch_op.alter_column(
- 'dag_id', existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
+ "dag_id", existing_type=sa.String(length=ID_LEN), type_=string_id_col_type, nullable=False
)
- batch_op.drop_index('idx_task_reschedule_dag_task_date')
+ batch_op.drop_index("idx_task_reschedule_dag_task_date")
# Then update the new column by selecting the right value from DagRun
# But first we will drop and recreate indexes to make it faster
- if dialect_name == 'postgresql':
+ if dialect_name == "postgresql":
# Recreate task_instance, without execution_date and with dagrun.run_id
op.execute(
"""
@@ -200,87 +202,87 @@ def upgrade():
INNER JOIN dag_run ON dag_run.dag_id = ti.dag_id AND dag_run.execution_date = ti.execution_date;
"""
)
- op.drop_table('task_instance')
- op.rename_table('new_task_instance', 'task_instance')
+ op.drop_table("task_instance")
+ op.rename_table("new_task_instance", "task_instance")
# Fix up columns after the 'create table as select'
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
batch_op.alter_column(
- 'pool', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+ "pool", existing_type=string_id_col_type, existing_nullable=True, nullable=False
)
- batch_op.alter_column('max_tries', existing_type=sa.Integer(), server_default="-1")
+ batch_op.alter_column("max_tries", existing_type=sa.Integer(), server_default="-1")
batch_op.alter_column(
- 'pool_slots', existing_type=sa.Integer(), existing_nullable=True, nullable=False
+ "pool_slots", existing_type=sa.Integer(), existing_nullable=True, nullable=False
)
else:
update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.run_id)
op.execute(update_query)
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
- if dialect_name != 'postgresql':
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ if dialect_name != "postgresql":
# TODO: Is this right for non-postgres?
- if dialect_name == 'mssql':
+ if dialect_name == "mssql":
constraints = get_mssql_table_constraints(conn, "task_instance")
- pk, _ = constraints['PRIMARY KEY'].popitem()
- batch_op.drop_constraint(pk, type_='primary')
- elif dialect_name not in ('sqlite'):
- batch_op.drop_constraint('task_instance_pkey', type_='primary')
- batch_op.drop_index('ti_dag_date')
- batch_op.drop_index('ti_state_lkp')
- batch_op.drop_column('execution_date')
+ pk, _ = constraints["PRIMARY KEY"].popitem()
+ batch_op.drop_constraint(pk, type_="primary")
+ elif dialect_name != "sqlite":
+ batch_op.drop_constraint("task_instance_pkey", type_="primary")
+ batch_op.drop_index("ti_dag_date")
+ batch_op.drop_index("ti_state_lkp")
+ batch_op.drop_column("execution_date")
# Then make it non-nullable
batch_op.alter_column(
- 'run_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+ "run_id", existing_type=string_id_col_type, existing_nullable=True, nullable=False
)
batch_op.alter_column(
- 'dag_id', existing_type=string_id_col_type, existing_nullable=True, nullable=False
+ "dag_id", existing_type=string_id_col_type, existing_nullable=True, nullable=False
)
- batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'run_id'])
+ batch_op.create_primary_key("task_instance_pkey", ["dag_id", "task_id", "run_id"])
batch_op.create_foreign_key(
- 'task_instance_dag_run_fkey',
- 'dag_run',
- ['dag_id', 'run_id'],
- ['dag_id', 'run_id'],
- ondelete='CASCADE',
+ "task_instance_dag_run_fkey",
+ "dag_run",
+ ["dag_id", "run_id"],
+ ["dag_id", "run_id"],
+ ondelete="CASCADE",
)
- batch_op.create_index('ti_dag_run', ['dag_id', 'run_id'])
- batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'run_id', 'state'])
- if dialect_name == 'postgresql':
- batch_op.create_index('ti_dag_state', ['dag_id', 'state'])
- batch_op.create_index('ti_job_id', ['job_id'])
- batch_op.create_index('ti_pool', ['pool', 'state', 'priority_weight'])
- batch_op.create_index('ti_state', ['state'])
+ batch_op.create_index("ti_dag_run", ["dag_id", "run_id"])
+ batch_op.create_index("ti_state_lkp", ["dag_id", "task_id", "run_id", "state"])
+ if dialect_name == "postgresql":
+ batch_op.create_index("ti_dag_state", ["dag_id", "state"])
+ batch_op.create_index("ti_job_id", ["job_id"])
+ batch_op.create_index("ti_pool", ["pool", "state", "priority_weight"])
+ batch_op.create_index("ti_state", ["state"])
batch_op.create_foreign_key(
- 'task_instance_trigger_id_fkey', 'trigger', ['trigger_id'], ['id'], ondelete="CASCADE"
+ "task_instance_trigger_id_fkey", "trigger", ["trigger_id"], ["id"], ondelete="CASCADE"
)
- batch_op.create_index('ti_trigger_id', ['trigger_id'])
+ batch_op.create_index("ti_trigger_id", ["trigger_id"])
- with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
- batch_op.drop_column('execution_date')
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.drop_column("execution_date")
batch_op.create_index(
- 'idx_task_reschedule_dag_task_run',
- ['dag_id', 'task_id', 'run_id'],
+ "idx_task_reschedule_dag_task_run",
+ ["dag_id", "task_id", "run_id"],
unique=False,
)
# _Now_ that there is a unique constraint on the columns in TI, we can re-create the FK from TaskReschedule
batch_op.create_foreign_key(
- 'task_reschedule_ti_fkey',
- 'task_instance',
- ['dag_id', 'task_id', 'run_id'],
- ['dag_id', 'task_id', 'run_id'],
- ondelete='CASCADE',
+ "task_reschedule_ti_fkey",
+ "task_instance",
+ ["dag_id", "task_id", "run_id"],
+ ["dag_id", "task_id", "run_id"],
+ ondelete="CASCADE",
)
# https://docs.microsoft.com/en-us/sql/relational-databases/errors-events/mssqlserver-1785-database-engine-error?view=sql-server-ver15
- ondelete = 'CASCADE' if dialect_name != 'mssql' else 'NO ACTION'
+ ondelete = "CASCADE" if dialect_name != "mssql" else "NO ACTION"
batch_op.create_foreign_key(
- 'task_reschedule_dr_fkey',
- 'dag_run',
- ['dag_id', 'run_id'],
- ['dag_id', 'run_id'],
+ "task_reschedule_dr_fkey",
+ "dag_run",
+ ["dag_id", "run_id"],
+ ["dag_id", "run_id"],
ondelete=ondelete,
)
@@ -291,8 +293,8 @@ def downgrade():
dt_type = TIMESTAMP
string_id_col_type = StringID()
- op.add_column('task_instance', sa.Column('execution_date', dt_type, nullable=True))
- op.add_column('task_reschedule', sa.Column('execution_date', dt_type, nullable=True))
+ op.add_column("task_instance", sa.Column("execution_date", dt_type, nullable=True))
+ op.add_column("task_reschedule", sa.Column("execution_date", dt_type, nullable=True))
update_query = _multi_table_update(dialect_name, task_instance, task_instance.c.execution_date)
op.execute(update_query)
@@ -300,71 +302,75 @@ def downgrade():
update_query = _multi_table_update(dialect_name, task_reschedule, task_reschedule.c.execution_date)
op.execute(update_query)
- with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
- batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.alter_column("execution_date", existing_type=dt_type, existing_nullable=True, nullable=False)
# Can't drop PK index while there is a FK referencing it
- batch_op.drop_constraint('task_reschedule_ti_fkey', type_='foreignkey')
- batch_op.drop_constraint('task_reschedule_dr_fkey', type_='foreignkey')
- batch_op.drop_index('idx_task_reschedule_dag_task_run')
-
- with op.batch_alter_table('task_instance', schema=None) as batch_op:
- batch_op.drop_constraint('task_instance_pkey', type_='primary')
- batch_op.alter_column('execution_date', existing_type=dt_type, existing_nullable=True, nullable=False)
- if dialect_name != 'mssql':
+ batch_op.drop_constraint("task_reschedule_ti_fkey", type_="foreignkey")
+ batch_op.drop_constraint("task_reschedule_dr_fkey", type_="foreignkey")
+ batch_op.drop_index("idx_task_reschedule_dag_task_run")
+
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.drop_constraint("task_instance_pkey", type_="primary")
+ batch_op.alter_column("execution_date", existing_type=dt_type, existing_nullable=True, nullable=False)
+ if dialect_name != "mssql":
batch_op.alter_column(
- 'dag_id', existing_type=string_id_col_type, existing_nullable=False, nullable=True
+ "dag_id", existing_type=string_id_col_type, existing_nullable=False, nullable=True
)
- batch_op.create_primary_key('task_instance_pkey', ['dag_id', 'task_id', 'execution_date'])
+ batch_op.create_primary_key("task_instance_pkey", ["dag_id", "task_id", "execution_date"])
- batch_op.drop_constraint('task_instance_dag_run_fkey', type_='foreignkey')
- batch_op.drop_index('ti_dag_run')
- batch_op.drop_index('ti_state_lkp')
- batch_op.create_index('ti_state_lkp', ['dag_id', 'task_id', 'execution_date', 'state'])
- batch_op.create_index('ti_dag_date', ['dag_id', 'execution_date'], unique=False)
+ batch_op.drop_constraint("task_instance_dag_run_fkey", type_="foreignkey")
+ batch_op.drop_index("ti_dag_run")
+ batch_op.drop_index("ti_state_lkp")
+ batch_op.create_index("ti_state_lkp", ["dag_id", "task_id", "execution_date", "state"])
+ batch_op.create_index("ti_dag_date", ["dag_id", "execution_date"], unique=False)
- batch_op.drop_column('run_id')
+ batch_op.drop_column("run_id")
- with op.batch_alter_table('task_reschedule', schema=None) as batch_op:
- batch_op.drop_column('run_id')
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.drop_column("run_id")
batch_op.create_index(
- 'idx_task_reschedule_dag_task_date',
- ['dag_id', 'task_id', 'execution_date'],
+ "idx_task_reschedule_dag_task_date",
+ ["dag_id", "task_id", "execution_date"],
unique=False,
)
# Can only create FK once there is an index on these columns
batch_op.create_foreign_key(
- 'task_reschedule_dag_task_date_fkey',
- 'task_instance',
- ['dag_id', 'task_id', 'execution_date'],
- ['dag_id', 'task_id', 'execution_date'],
- ondelete='CASCADE',
+ "task_reschedule_dag_task_date_fkey",
+ "task_instance",
+ ["dag_id", "task_id", "execution_date"],
+ ["dag_id", "task_id", "execution_date"],
+ ondelete="CASCADE",
)
+ if dialect_name == "mysql":
+ batch_op.create_index(
+ "task_reschedule_dag_task_date_fkey", ["dag_id", "execution_date"], unique=False
+ )
if dialect_name == "mssql":
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
- batch_op.drop_constraint('dag_run_dag_id_execution_date_key', 'unique')
- batch_op.drop_constraint('dag_run_dag_id_run_id_key', 'unique')
- batch_op.drop_index('dag_id_state')
- batch_op.drop_index('idx_dag_run_running_dags')
- batch_op.drop_index('idx_dag_run_queued_dags')
- batch_op.drop_index('idx_dag_run_dag_id')
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.drop_constraint("dag_run_dag_id_execution_date_key", "unique")
+ batch_op.drop_constraint("dag_run_dag_id_run_id_key", "unique")
+ batch_op.drop_index("dag_id_state")
+ batch_op.drop_index("idx_dag_run_running_dags")
+ batch_op.drop_index("idx_dag_run_queued_dags")
+ batch_op.drop_index("idx_dag_run_dag_id")
- batch_op.alter_column('dag_id', existing_type=string_id_col_type, nullable=True)
- batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True)
- batch_op.alter_column('run_id', existing_type=string_id_col_type, nullable=True)
+ batch_op.alter_column("dag_id", existing_type=string_id_col_type, nullable=True)
+ batch_op.alter_column("execution_date", existing_type=dt_type, nullable=True)
+ batch_op.alter_column("run_id", existing_type=string_id_col_type, nullable=True)
- batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
- batch_op.create_index('idx_dag_run_dag_id', ['dag_id'])
+ batch_op.create_index("dag_id_state", ["dag_id", "state"], unique=False)
+ batch_op.create_index("idx_dag_run_dag_id", ["dag_id"])
batch_op.create_index(
- 'idx_dag_run_running_dags',
+ "idx_dag_run_running_dags",
["state", "dag_id"],
mssql_where=sa.text("state='running'"),
)
batch_op.create_index(
- 'idx_dag_run_queued_dags',
+ "idx_dag_run_queued_dags",
["state", "dag_id"],
mssql_where=sa.text("state='queued'"),
)
@@ -379,12 +385,12 @@ def downgrade():
WHERE dag_id IS NOT NULL and run_id is not null"""
)
else:
- with op.batch_alter_table('dag_run', schema=None) as batch_op:
- batch_op.drop_index('dag_id_state')
- batch_op.alter_column('run_id', existing_type=sa.VARCHAR(length=250), nullable=True)
- batch_op.alter_column('execution_date', existing_type=dt_type, nullable=True)
- batch_op.alter_column('dag_id', existing_type=sa.VARCHAR(length=250), nullable=True)
- batch_op.create_index('dag_id_state', ['dag_id', 'state'], unique=False)
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.drop_index("dag_id_state")
+ batch_op.alter_column("run_id", existing_type=sa.VARCHAR(length=250), nullable=True)
+ batch_op.alter_column("execution_date", existing_type=dt_type, nullable=True)
+ batch_op.alter_column("dag_id", existing_type=sa.VARCHAR(length=250), nullable=True)
+ batch_op.create_index("dag_id_state", ["dag_id", "state"], unique=False)
def _multi_table_update(dialect_name, target, column):
diff --git a/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py b/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py
index d401e241b862d..fe931e62ddf77 100644
--- a/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py
+++ b/airflow/migrations/versions/0094_2_2_3_add_has_import_errors_column_to_dagmodel.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add has_import_errors column to DagModel
Revision ID: be2bfac3da23
@@ -23,24 +22,25 @@
Create Date: 2021-11-04 20:33:11.009547
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'be2bfac3da23'
-down_revision = '7b2661a43ba3'
+revision = "be2bfac3da23"
+down_revision = "7b2661a43ba3"
branch_labels = None
depends_on = None
-airflow_version = '2.2.3'
+airflow_version = "2.2.3"
def upgrade():
"""Apply Add has_import_errors column to DagModel"""
- op.add_column("dag", sa.Column("has_import_errors", sa.Boolean(), server_default='0'))
+ op.add_column("dag", sa.Column("has_import_errors", sa.Boolean(), server_default="0"))
def downgrade():
"""Unapply Add has_import_errors column to DagModel"""
- with op.batch_alter_table('dag') as batch_op:
- batch_op.drop_column('has_import_errors', mssql_drop_default=True)
+ with op.batch_alter_table("dag") as batch_op:
+ batch_op.drop_column("has_import_errors", mssql_drop_default=True)
diff --git a/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py b/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py
index 3b70acbdab529..ba5f32253164a 100644
--- a/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py
+++ b/airflow/migrations/versions/0095_2_2_4_add_session_table_to_db.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Create a ``session`` table to store web session data
Revision ID: c381b21cb7e4
@@ -23,30 +22,31 @@
Create Date: 2022-01-25 13:56:35.069429
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'c381b21cb7e4'
-down_revision = 'be2bfac3da23'
+revision = "c381b21cb7e4"
+down_revision = "be2bfac3da23"
branch_labels = None
depends_on = None
-airflow_version = '2.2.4'
+airflow_version = "2.2.4"
-TABLE_NAME = 'session'
+TABLE_NAME = "session"
def upgrade():
"""Apply Create a ``session`` table to store web session data"""
op.create_table(
TABLE_NAME,
- sa.Column('id', sa.Integer()),
- sa.Column('session_id', sa.String(255)),
- sa.Column('data', sa.LargeBinary()),
- sa.Column('expiry', sa.DateTime()),
- sa.PrimaryKeyConstraint('id'),
- sa.UniqueConstraint('session_id'),
+ sa.Column("id", sa.Integer()),
+ sa.Column("session_id", sa.String(255)),
+ sa.Column("data", sa.LargeBinary()),
+ sa.Column("expiry", sa.DateTime()),
+ sa.PrimaryKeyConstraint("id"),
+ sa.UniqueConstraint("session_id"),
)
diff --git a/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py b/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py
index 2d08d7f922a58..447081a40a763 100644
--- a/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py
+++ b/airflow/migrations/versions/0096_2_2_4_adding_index_for_dag_id_in_job.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add index for ``dag_id`` column in ``job`` table.
Revision ID: 587bdf053233
@@ -23,22 +22,23 @@
Create Date: 2021-12-14 10:20:12.482940
"""
+from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
-revision = '587bdf053233'
-down_revision = 'c381b21cb7e4'
+revision = "587bdf053233"
+down_revision = "c381b21cb7e4"
branch_labels = None
depends_on = None
-airflow_version = '2.2.4'
+airflow_version = "2.2.4"
def upgrade():
"""Apply Add index for ``dag_id`` column in ``job`` table."""
- op.create_index('idx_job_dag_id', 'job', ['dag_id'], unique=False)
+ op.create_index("idx_job_dag_id", "job", ["dag_id"], unique=False)
def downgrade():
"""Unapply Add index for ``dag_id`` column in ``job`` table."""
- op.drop_index('idx_job_dag_id', table_name='job')
+ op.drop_index("idx_job_dag_id", table_name="job")
diff --git a/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py b/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py
index be028e83378fe..b62b3136ae9ce 100644
--- a/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py
+++ b/airflow/migrations/versions/0097_2_3_0_increase_length_of_email_and_username.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Increase length of email and username in ``ab_user`` and ``ab_register_user`` table to ``256`` characters
Revision ID: 5e3ec427fdd3
@@ -23,6 +22,7 @@
Create Date: 2021-12-01 11:49:26.390210
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -30,53 +30,53 @@
from airflow.migrations.utils import get_mssql_table_constraints
# revision identifiers, used by Alembic.
-revision = '5e3ec427fdd3'
-down_revision = '587bdf053233'
+revision = "5e3ec427fdd3"
+down_revision = "587bdf053233"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
"""Increase length of email from 64 to 256 characters"""
- with op.batch_alter_table('ab_user') as batch_op:
- batch_op.alter_column('username', type_=sa.String(256))
- batch_op.alter_column('email', type_=sa.String(256))
- with op.batch_alter_table('ab_register_user') as batch_op:
- batch_op.alter_column('username', type_=sa.String(256))
- batch_op.alter_column('email', type_=sa.String(256))
+ with op.batch_alter_table("ab_user") as batch_op:
+ batch_op.alter_column("username", type_=sa.String(256))
+ batch_op.alter_column("email", type_=sa.String(256))
+ with op.batch_alter_table("ab_register_user") as batch_op:
+ batch_op.alter_column("username", type_=sa.String(256))
+ batch_op.alter_column("email", type_=sa.String(256))
def downgrade():
"""Revert length of email from 256 to 64 characters"""
conn = op.get_bind()
- if conn.dialect.name != 'mssql':
- with op.batch_alter_table('ab_user') as batch_op:
- batch_op.alter_column('username', type_=sa.String(64), nullable=False)
- batch_op.alter_column('email', type_=sa.String(64))
- with op.batch_alter_table('ab_register_user') as batch_op:
- batch_op.alter_column('username', type_=sa.String(64))
- batch_op.alter_column('email', type_=sa.String(64))
+ if conn.dialect.name != "mssql":
+ with op.batch_alter_table("ab_user") as batch_op:
+ batch_op.alter_column("username", type_=sa.String(64), nullable=False)
+ batch_op.alter_column("email", type_=sa.String(64))
+ with op.batch_alter_table("ab_register_user") as batch_op:
+ batch_op.alter_column("username", type_=sa.String(64))
+ batch_op.alter_column("email", type_=sa.String(64))
else:
# MSSQL doesn't drop implicit unique constraints it created
# We need to drop the two unique constraints explicitly
- with op.batch_alter_table('ab_user') as batch_op:
+ with op.batch_alter_table("ab_user") as batch_op:
# Drop the unique constraint on username and email
- constraints = get_mssql_table_constraints(conn, 'ab_user')
- unique_key, _ = constraints['UNIQUE'].popitem()
- batch_op.drop_constraint(unique_key, type_='unique')
- unique_key, _ = constraints['UNIQUE'].popitem()
- batch_op.drop_constraint(unique_key, type_='unique')
- batch_op.alter_column('username', type_=sa.String(64), nullable=False)
- batch_op.create_unique_constraint(None, ['username'])
- batch_op.alter_column('email', type_=sa.String(64))
- batch_op.create_unique_constraint(None, ['email'])
+ constraints = get_mssql_table_constraints(conn, "ab_user")
+ unique_key, _ = constraints["UNIQUE"].popitem()
+ batch_op.drop_constraint(unique_key, type_="unique")
+ unique_key, _ = constraints["UNIQUE"].popitem()
+ batch_op.drop_constraint(unique_key, type_="unique")
+ batch_op.alter_column("username", type_=sa.String(64), nullable=False)
+ batch_op.create_unique_constraint(None, ["username"])
+ batch_op.alter_column("email", type_=sa.String(64))
+ batch_op.create_unique_constraint(None, ["email"])
- with op.batch_alter_table('ab_register_user') as batch_op:
+ with op.batch_alter_table("ab_register_user") as batch_op:
# Drop the unique constraint on username and email
- constraints = get_mssql_table_constraints(conn, 'ab_register_user')
- for k, _ in constraints.get('UNIQUE').items():
- batch_op.drop_constraint(k, type_='unique')
- batch_op.alter_column('username', type_=sa.String(64))
- batch_op.create_unique_constraint(None, ['username'])
- batch_op.alter_column('email', type_=sa.String(64))
+ constraints = get_mssql_table_constraints(conn, "ab_register_user")
+ for k, _ in constraints.get("UNIQUE").items():
+ batch_op.drop_constraint(k, type_="unique")
+ batch_op.alter_column("username", type_=sa.String(64))
+ batch_op.create_unique_constraint(None, ["username"])
+ batch_op.alter_column("email", type_=sa.String(64))
diff --git a/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py b/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py
index 811ece55c54bc..9d9d8c72eea10 100644
--- a/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py
+++ b/airflow/migrations/versions/0098_2_3_0_added_timetable_description_column.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``timetable_description`` column to DagModel for UI.
Revision ID: 786e3737b18f
@@ -23,30 +22,31 @@
Create Date: 2021-10-15 13:33:04.754052
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = '786e3737b18f'
-down_revision = '5e3ec427fdd3'
+revision = "786e3737b18f"
+down_revision = "5e3ec427fdd3"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
"""Apply Add ``timetable_description`` column to DagModel for UI."""
- with op.batch_alter_table('dag', schema=None) as batch_op:
- batch_op.add_column(sa.Column('timetable_description', sa.String(length=1000), nullable=True))
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("timetable_description", sa.String(length=1000), nullable=True))
def downgrade():
"""Unapply Add ``timetable_description`` column to DagModel for UI."""
- is_sqlite = bool(op.get_bind().dialect.name == 'sqlite')
+ is_sqlite = bool(op.get_bind().dialect.name == "sqlite")
if is_sqlite:
- op.execute('PRAGMA foreign_keys=off')
- with op.batch_alter_table('dag') as batch_op:
- batch_op.drop_column('timetable_description')
+ op.execute("PRAGMA foreign_keys=off")
+ with op.batch_alter_table("dag") as batch_op:
+ batch_op.drop_column("timetable_description")
if is_sqlite:
- op.execute('PRAGMA foreign_keys=on')
+ op.execute("PRAGMA foreign_keys=on")
diff --git a/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py b/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py
index 05fc2767fa9fe..2fa870943c186 100644
--- a/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py
+++ b/airflow/migrations/versions/0099_2_3_0_add_task_log_filename_template_model.py
@@ -15,13 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``LogTemplate`` table to track changes to config values ``log_filename_template``
Revision ID: f9da662e7089
Revises: 786e3737b18f
Create Date: 2021-12-09 06:11:21.044940
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy import Column, ForeignKey, Integer, Text
@@ -34,7 +34,7 @@
down_revision = "786e3737b18f"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
diff --git a/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py b/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py
index 741c02d0ff6fb..d11953009aeed 100644
--- a/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py
+++ b/airflow/migrations/versions/0100_2_3_0_add_taskmap_and_map_id_on_taskinstance.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add ``map_index`` column to TaskInstance to identify task-mapping,
and a ``task_map`` table to track mapping values from XCom.
@@ -23,6 +22,7 @@
Revises: f9da662e7089
Create Date: 2021-12-13 22:59:41.052584
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, text
@@ -35,7 +35,7 @@
down_revision = "f9da662e7089"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
diff --git a/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py b/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py
index f7276d8b2cac2..34e832c4a0e96 100644
--- a/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py
+++ b/airflow/migrations/versions/0101_2_3_0_add_data_compressed_to_serialized_dag.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""add data_compressed to serialized_dag
Revision ID: a3bcd0914482
@@ -23,25 +22,26 @@
Create Date: 2022-02-03 22:40:59.841119
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
-revision = 'a3bcd0914482'
-down_revision = 'e655c0453f75'
+revision = "a3bcd0914482"
+down_revision = "e655c0453f75"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
- with op.batch_alter_table('serialized_dag') as batch_op:
- batch_op.alter_column('data', existing_type=sa.JSON, nullable=True)
- batch_op.add_column(sa.Column('data_compressed', sa.LargeBinary, nullable=True))
+ with op.batch_alter_table("serialized_dag") as batch_op:
+ batch_op.alter_column("data", existing_type=sa.JSON, nullable=True)
+ batch_op.add_column(sa.Column("data_compressed", sa.LargeBinary, nullable=True))
def downgrade():
- with op.batch_alter_table('serialized_dag') as batch_op:
- batch_op.alter_column('data', existing_type=sa.JSON, nullable=False)
- batch_op.drop_column('data_compressed')
+ with op.batch_alter_table("serialized_dag") as batch_op:
+ batch_op.alter_column("data", existing_type=sa.JSON, nullable=False)
+ batch_op.drop_column("data_compressed")
diff --git a/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py b/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py
index f987a1fe12606..b8b06e802c931 100644
--- a/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py
+++ b/airflow/migrations/versions/0102_2_3_0_switch_xcom_table_to_use_run_id.py
@@ -15,13 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Switch XCom table to use ``run_id`` and add ``map_index``.
Revision ID: c306b5b5ae4a
Revises: a3bcd0914482
Create Date: 2022-01-19 03:20:35.329037
"""
+from __future__ import annotations
+
from typing import Sequence
from alembic import op
@@ -35,7 +36,7 @@
down_revision = "a3bcd0914482"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
metadata = MetaData()
@@ -167,8 +168,8 @@ def downgrade():
op.drop_table("xcom")
op.rename_table("__airflow_tmp_xcom", "xcom")
- if conn.dialect.name == 'mssql':
- constraints = get_mssql_table_constraints(conn, 'xcom')
- pk, _ = constraints['PRIMARY KEY'].popitem()
- op.drop_constraint(pk, 'xcom', type_='primary')
+ if conn.dialect.name == "mssql":
+ constraints = get_mssql_table_constraints(conn, "xcom")
+ pk, _ = constraints["PRIMARY KEY"].popitem()
+ op.drop_constraint(pk, "xcom", type_="primary")
op.create_primary_key("pk_xcom", "xcom", ["dag_id", "task_id", "execution_date", "key"])
diff --git a/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py b/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py
index 637abe8aee1b9..09ad52727d0ec 100644
--- a/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py
+++ b/airflow/migrations/versions/0103_2_3_0_add_callback_request_table.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""add callback request table
Revision ID: c97c2ab6aa23
@@ -23,6 +22,7 @@
Create Date: 2022-01-28 21:11:11.857010
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -32,25 +32,25 @@
from airflow.utils.sqlalchemy import ExtendedJSON
# revision identifiers, used by Alembic.
-revision = 'c97c2ab6aa23'
-down_revision = 'c306b5b5ae4a'
+revision = "c97c2ab6aa23"
+down_revision = "c306b5b5ae4a"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
-TABLE_NAME = 'callback_request'
+TABLE_NAME = "callback_request"
def upgrade():
op.create_table(
TABLE_NAME,
- sa.Column('id', sa.Integer(), nullable=False, primary_key=True),
- sa.Column('created_at', TIMESTAMP, default=func.now(), nullable=False),
- sa.Column('priority_weight', sa.Integer(), default=1, nullable=False),
- sa.Column('callback_data', ExtendedJSON, nullable=False),
- sa.Column('callback_type', sa.String(20), nullable=False),
- sa.Column('dag_directory', sa.String(length=1000), nullable=True),
- sa.PrimaryKeyConstraint('id'),
+ sa.Column("id", sa.Integer(), nullable=False, primary_key=True),
+ sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.Column("priority_weight", sa.Integer(), default=1, nullable=False),
+ sa.Column("callback_data", ExtendedJSON, nullable=False),
+ sa.Column("callback_type", sa.String(20), nullable=False),
+ sa.Column("dag_directory", sa.String(length=1000), nullable=True),
+ sa.PrimaryKeyConstraint("id"),
)
diff --git a/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py b/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py
index f64f95b2b9bde..ff52470f42164 100644
--- a/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py
+++ b/airflow/migrations/versions/0104_2_3_0_migrate_rtif_to_use_run_id_and_map_index.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Migrate RTIF to use run_id and map_index
Revision ID: 4eaab2fe6582
@@ -23,6 +22,7 @@
Create Date: 2022-03-03 17:48:29.955821
"""
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -35,11 +35,11 @@
ID_LEN = 250
# revision identifiers, used by Alembic.
-revision = '4eaab2fe6582'
-down_revision = 'c97c2ab6aa23'
+revision = "4eaab2fe6582"
+down_revision = "c97c2ab6aa23"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
# Just Enough Table to run the conditions for update.
@@ -49,42 +49,42 @@ def tables(for_downgrade=False):
global task_instance, rendered_task_instance_fields, dag_run
metadata = sa.MetaData()
task_instance = sa.Table(
- 'task_instance',
+ "task_instance",
metadata,
- sa.Column('task_id', StringID()),
- sa.Column('dag_id', StringID()),
- sa.Column('run_id', StringID()),
- sa.Column('execution_date', TIMESTAMP),
+ sa.Column("task_id", StringID()),
+ sa.Column("dag_id", StringID()),
+ sa.Column("run_id", StringID()),
+ sa.Column("execution_date", TIMESTAMP),
)
rendered_task_instance_fields = sa.Table(
- 'rendered_task_instance_fields',
+ "rendered_task_instance_fields",
metadata,
- sa.Column('dag_id', StringID()),
- sa.Column('task_id', StringID()),
- sa.Column('run_id', StringID()),
- sa.Column('execution_date', TIMESTAMP),
- sa.Column('rendered_fields', sqlalchemy_jsonfield.JSONField(), nullable=False),
- sa.Column('k8s_pod_yaml', sqlalchemy_jsonfield.JSONField(), nullable=True),
+ sa.Column("dag_id", StringID()),
+ sa.Column("task_id", StringID()),
+ sa.Column("run_id", StringID()),
+ sa.Column("execution_date", TIMESTAMP),
+ sa.Column("rendered_fields", sqlalchemy_jsonfield.JSONField(), nullable=False),
+ sa.Column("k8s_pod_yaml", sqlalchemy_jsonfield.JSONField(), nullable=True),
)
if for_downgrade:
rendered_task_instance_fields.append_column(
- sa.Column('map_index', sa.Integer(), server_default='-1'),
+ sa.Column("map_index", sa.Integer(), server_default="-1"),
)
rendered_task_instance_fields.append_constraint(
ForeignKeyConstraint(
- ['dag_id', 'run_id'],
+ ["dag_id", "run_id"],
["dag_run.dag_id", "dag_run.run_id"],
- name='rtif_dag_run_fkey',
+ name="rtif_dag_run_fkey",
ondelete="CASCADE",
),
)
dag_run = sa.Table(
- 'dag_run',
+ "dag_run",
metadata,
- sa.Column('dag_id', StringID()),
- sa.Column('run_id', StringID()),
- sa.Column('execution_date', TIMESTAMP),
+ sa.Column("dag_id", StringID()),
+ sa.Column("run_id", StringID()),
+ sa.Column("execution_date", TIMESTAMP),
)
@@ -110,44 +110,44 @@ def upgrade():
tables()
dialect_name = op.get_bind().dialect.name
- with op.batch_alter_table('rendered_task_instance_fields') as batch_op:
- batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False))
+ with op.batch_alter_table("rendered_task_instance_fields") as batch_op:
+ batch_op.add_column(sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False))
rendered_task_instance_fields.append_column(
- sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False)
+ sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False)
)
- batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True))
+ batch_op.add_column(sa.Column("run_id", type_=StringID(), nullable=True))
update_query = _multi_table_update(
dialect_name, rendered_task_instance_fields, rendered_task_instance_fields.c.run_id
)
op.execute(update_query)
with op.batch_alter_table(
- 'rendered_task_instance_fields', copy_from=rendered_task_instance_fields
+ "rendered_task_instance_fields", copy_from=rendered_task_instance_fields
) as batch_op:
- if dialect_name == 'mssql':
- constraints = get_mssql_table_constraints(op.get_bind(), 'rendered_task_instance_fields')
- pk, _ = constraints['PRIMARY KEY'].popitem()
- batch_op.drop_constraint(pk, type_='primary')
- elif dialect_name != 'sqlite':
- batch_op.drop_constraint('rendered_task_instance_fields_pkey', type_='primary')
- batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False)
- batch_op.drop_column('execution_date')
+ if dialect_name == "mssql":
+ constraints = get_mssql_table_constraints(op.get_bind(), "rendered_task_instance_fields")
+ pk, _ = constraints["PRIMARY KEY"].popitem()
+ batch_op.drop_constraint(pk, type_="primary")
+ elif dialect_name != "sqlite":
+ batch_op.drop_constraint("rendered_task_instance_fields_pkey", type_="primary")
+ batch_op.alter_column("run_id", existing_type=StringID(), existing_nullable=True, nullable=False)
+ batch_op.drop_column("execution_date")
batch_op.create_primary_key(
- 'rendered_task_instance_fields_pkey', ['dag_id', 'task_id', 'run_id', 'map_index']
+ "rendered_task_instance_fields_pkey", ["dag_id", "task_id", "run_id", "map_index"]
)
batch_op.create_foreign_key(
- 'rtif_ti_fkey',
- 'task_instance',
- ['dag_id', 'task_id', 'run_id', 'map_index'],
- ['dag_id', 'task_id', 'run_id', 'map_index'],
- ondelete='CASCADE',
+ "rtif_ti_fkey",
+ "task_instance",
+ ["dag_id", "task_id", "run_id", "map_index"],
+ ["dag_id", "task_id", "run_id", "map_index"],
+ ondelete="CASCADE",
)
def downgrade():
tables(for_downgrade=True)
dialect_name = op.get_bind().dialect.name
- op.add_column('rendered_task_instance_fields', sa.Column('execution_date', TIMESTAMP, nullable=True))
+ op.add_column("rendered_task_instance_fields", sa.Column("execution_date", TIMESTAMP, nullable=True))
update_query = _multi_table_update(
dialect_name, rendered_task_instance_fields, rendered_task_instance_fields.c.execution_date
@@ -155,14 +155,14 @@ def downgrade():
op.execute(update_query)
with op.batch_alter_table(
- 'rendered_task_instance_fields', copy_from=rendered_task_instance_fields
+ "rendered_task_instance_fields", copy_from=rendered_task_instance_fields
) as batch_op:
- batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False)
- if dialect_name != 'sqlite':
- batch_op.drop_constraint('rtif_ti_fkey', type_='foreignkey')
- batch_op.drop_constraint('rendered_task_instance_fields_pkey', type_='primary')
+ batch_op.alter_column("execution_date", existing_type=TIMESTAMP, nullable=False)
+ if dialect_name != "sqlite":
+ batch_op.drop_constraint("rtif_ti_fkey", type_="foreignkey")
+ batch_op.drop_constraint("rendered_task_instance_fields_pkey", type_="primary")
batch_op.create_primary_key(
- 'rendered_task_instance_fields_pkey', ['dag_id', 'task_id', 'execution_date']
+ "rendered_task_instance_fields_pkey", ["dag_id", "task_id", "execution_date"]
)
- batch_op.drop_column('map_index', mssql_drop_default=True)
- batch_op.drop_column('run_id')
+ batch_op.drop_column("map_index", mssql_drop_default=True)
+ batch_op.drop_column("run_id")
diff --git a/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py b/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py
index 303e1eda58241..d65cf6b548947 100644
--- a/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py
+++ b/airflow/migrations/versions/0105_2_3_0_add_map_index_to_taskfail.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add map_index to TaskFail
Drop index idx_task_fail_dag_task_date
@@ -28,8 +27,7 @@
Revises: 4eaab2fe6582
Create Date: 2022-03-14 10:31:11.220720
"""
-
-from typing import List
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -38,11 +36,11 @@
from airflow.migrations.db_types import TIMESTAMP, StringID
# revision identifiers, used by Alembic.
-revision = '48925b2719cb'
-down_revision = '4eaab2fe6582'
+revision = "48925b2719cb"
+down_revision = "4eaab2fe6582"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
ID_LEN = 250
@@ -51,29 +49,29 @@ def tables():
global task_instance, task_fail, dag_run
metadata = sa.MetaData()
task_instance = sa.Table(
- 'task_instance',
+ "task_instance",
metadata,
- sa.Column('task_id', StringID()),
- sa.Column('dag_id', StringID()),
- sa.Column('run_id', StringID()),
- sa.Column('map_index', sa.Integer(), server_default='-1'),
- sa.Column('execution_date', TIMESTAMP),
+ sa.Column("task_id", StringID()),
+ sa.Column("dag_id", StringID()),
+ sa.Column("run_id", StringID()),
+ sa.Column("map_index", sa.Integer(), server_default="-1"),
+ sa.Column("execution_date", TIMESTAMP),
)
task_fail = sa.Table(
- 'task_fail',
+ "task_fail",
metadata,
- sa.Column('dag_id', StringID()),
- sa.Column('task_id', StringID()),
- sa.Column('run_id', StringID()),
- sa.Column('map_index', StringID()),
- sa.Column('execution_date', TIMESTAMP),
+ sa.Column("dag_id", StringID()),
+ sa.Column("task_id", StringID()),
+ sa.Column("run_id", StringID()),
+ sa.Column("map_index", StringID()),
+ sa.Column("execution_date", TIMESTAMP),
)
dag_run = sa.Table(
- 'dag_run',
+ "dag_run",
metadata,
- sa.Column('dag_id', StringID()),
- sa.Column('run_id', StringID()),
- sa.Column('execution_date', TIMESTAMP),
+ sa.Column("dag_id", StringID()),
+ sa.Column("run_id", StringID()),
+ sa.Column("execution_date", TIMESTAMP),
)
@@ -81,7 +79,7 @@ def _update_value_from_dag_run(
dialect_name: str,
target_table: sa.Table,
target_column: ColumnElement,
- join_columns: List[str],
+ join_columns: list[str],
) -> Update:
"""
Grabs a value from the source table ``dag_run`` and updates target with this value.
@@ -108,51 +106,51 @@ def upgrade():
tables()
dialect_name = op.get_bind().dialect.name
- op.drop_index('idx_task_fail_dag_task_date', table_name='task_fail')
+ op.drop_index("idx_task_fail_dag_task_date", table_name="task_fail")
- with op.batch_alter_table('task_fail') as batch_op:
- batch_op.add_column(sa.Column('map_index', sa.Integer(), server_default='-1', nullable=False))
- batch_op.add_column(sa.Column('run_id', type_=StringID(), nullable=True))
+ with op.batch_alter_table("task_fail") as batch_op:
+ batch_op.add_column(sa.Column("map_index", sa.Integer(), server_default="-1", nullable=False))
+ batch_op.add_column(sa.Column("run_id", type_=StringID(), nullable=True))
update_query = _update_value_from_dag_run(
dialect_name=dialect_name,
target_table=task_fail,
target_column=task_fail.c.run_id,
- join_columns=['dag_id', 'execution_date'],
+ join_columns=["dag_id", "execution_date"],
)
op.execute(update_query)
- with op.batch_alter_table('task_fail') as batch_op:
- batch_op.alter_column('run_id', existing_type=StringID(), existing_nullable=True, nullable=False)
- batch_op.drop_column('execution_date')
+ with op.batch_alter_table("task_fail") as batch_op:
+ batch_op.alter_column("run_id", existing_type=StringID(), existing_nullable=True, nullable=False)
+ batch_op.drop_column("execution_date")
batch_op.create_foreign_key(
- 'task_fail_ti_fkey',
- 'task_instance',
- ['dag_id', 'task_id', 'run_id', 'map_index'],
- ['dag_id', 'task_id', 'run_id', 'map_index'],
- ondelete='CASCADE',
+ "task_fail_ti_fkey",
+ "task_instance",
+ ["dag_id", "task_id", "run_id", "map_index"],
+ ["dag_id", "task_id", "run_id", "map_index"],
+ ondelete="CASCADE",
)
def downgrade():
tables()
dialect_name = op.get_bind().dialect.name
- op.add_column('task_fail', sa.Column('execution_date', TIMESTAMP, nullable=True))
+ op.add_column("task_fail", sa.Column("execution_date", TIMESTAMP, nullable=True))
update_query = _update_value_from_dag_run(
dialect_name=dialect_name,
target_table=task_fail,
target_column=task_fail.c.execution_date,
- join_columns=['dag_id', 'run_id'],
+ join_columns=["dag_id", "run_id"],
)
op.execute(update_query)
- with op.batch_alter_table('task_fail', copy_from=task_fail) as batch_op:
- batch_op.alter_column('execution_date', existing_type=TIMESTAMP, nullable=False)
- if dialect_name != 'sqlite':
- batch_op.drop_constraint('task_fail_ti_fkey', type_='foreignkey')
- batch_op.drop_column('map_index', mssql_drop_default=True)
- batch_op.drop_column('run_id')
+ with op.batch_alter_table("task_fail") as batch_op:
+ batch_op.alter_column("execution_date", existing_type=TIMESTAMP, nullable=False)
+ if dialect_name != "sqlite":
+ batch_op.drop_constraint("task_fail_ti_fkey", type_="foreignkey")
+ batch_op.drop_column("map_index", mssql_drop_default=True)
+ batch_op.drop_column("run_id")
op.create_index(
- index_name='idx_task_fail_dag_task_date',
- table_name='task_fail',
- columns=['dag_id', 'task_id', 'execution_date'],
+ index_name="idx_task_fail_dag_task_date",
+ table_name="task_fail",
+ columns=["dag_id", "task_id", "execution_date"],
unique=False,
)
diff --git a/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py b/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py
index 81f6f1f34492b..13823ac354011 100644
--- a/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py
+++ b/airflow/migrations/versions/0106_2_3_0_update_migration_for_fab_tables_to_add_missing_constraints.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Update migration for FAB tables to add missing constraints
Revision ID: 909884dea523
@@ -23,7 +22,7 @@
Create Date: 2022-03-21 08:33:01.635688
"""
-
+from __future__ import annotations
import sqlalchemy as sa
from alembic import op
@@ -31,78 +30,78 @@
from airflow.migrations.utils import get_mssql_table_constraints
# revision identifiers, used by Alembic.
-revision = '909884dea523'
-down_revision = '48925b2719cb'
+revision = "909884dea523"
+down_revision = "48925b2719cb"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
"""Apply Update migration for FAB tables to add missing constraints"""
conn = op.get_bind()
- if conn.dialect.name == 'sqlite':
- op.execute('PRAGMA foreign_keys=OFF')
- with op.batch_alter_table('ab_view_menu', schema=None) as batch_op:
- batch_op.create_unique_constraint(batch_op.f('ab_view_menu_name_uq'), ['name'])
- op.execute('PRAGMA foreign_keys=ON')
- elif conn.dialect.name == 'mysql':
- with op.batch_alter_table('ab_register_user', schema=None) as batch_op:
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=False)
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=False)
- with op.batch_alter_table('ab_user', schema=None) as batch_op:
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=False)
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=False)
- elif conn.dialect.name == 'mssql':
- with op.batch_alter_table('ab_register_user') as batch_op:
+ if conn.dialect.name == "sqlite":
+ op.execute("PRAGMA foreign_keys=OFF")
+ with op.batch_alter_table("ab_view_menu", schema=None) as batch_op:
+ batch_op.create_unique_constraint(batch_op.f("ab_view_menu_name_uq"), ["name"])
+ op.execute("PRAGMA foreign_keys=ON")
+ elif conn.dialect.name == "mysql":
+ with op.batch_alter_table("ab_register_user", schema=None) as batch_op:
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=False)
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=False)
+ with op.batch_alter_table("ab_user", schema=None) as batch_op:
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=False)
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=False)
+ elif conn.dialect.name == "mssql":
+ with op.batch_alter_table("ab_register_user") as batch_op:
# Drop the unique constraint on username and email
- constraints = get_mssql_table_constraints(conn, 'ab_register_user')
- for k, _ in constraints.get('UNIQUE').items():
- batch_op.drop_constraint(k, type_='unique')
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=False)
- batch_op.create_unique_constraint(None, ['username'])
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=False)
- with op.batch_alter_table('ab_user') as batch_op:
+ constraints = get_mssql_table_constraints(conn, "ab_register_user")
+ for k, _ in constraints.get("UNIQUE").items():
+ batch_op.drop_constraint(k, type_="unique")
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=False)
+ batch_op.create_unique_constraint(None, ["username"])
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=False)
+ with op.batch_alter_table("ab_user") as batch_op:
# Drop the unique constraint on username and email
- constraints = get_mssql_table_constraints(conn, 'ab_user')
- for k, _ in constraints.get('UNIQUE').items():
- batch_op.drop_constraint(k, type_='unique')
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=False)
- batch_op.create_unique_constraint(None, ['username'])
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=False)
- batch_op.create_unique_constraint(None, ['email'])
+ constraints = get_mssql_table_constraints(conn, "ab_user")
+ for k, _ in constraints.get("UNIQUE").items():
+ batch_op.drop_constraint(k, type_="unique")
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=False)
+ batch_op.create_unique_constraint(None, ["username"])
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=False)
+ batch_op.create_unique_constraint(None, ["email"])
def downgrade():
"""Unapply Update migration for FAB tables to add missing constraints"""
conn = op.get_bind()
- if conn.dialect.name == 'sqlite':
- op.execute('PRAGMA foreign_keys=OFF')
- with op.batch_alter_table('ab_view_menu', schema=None) as batch_op:
- batch_op.drop_constraint('ab_view_menu_name_uq', type_='unique')
- op.execute('PRAGMA foreign_keys=ON')
- elif conn.dialect.name == 'mysql':
- with op.batch_alter_table('ab_user', schema=None) as batch_op:
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=True)
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=True, unique=True)
- with op.batch_alter_table('ab_register_user', schema=None) as batch_op:
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=True)
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=True, unique=True)
- elif conn.dialect.name == 'mssql':
- with op.batch_alter_table('ab_register_user') as batch_op:
+ if conn.dialect.name == "sqlite":
+ op.execute("PRAGMA foreign_keys=OFF")
+ with op.batch_alter_table("ab_view_menu", schema=None) as batch_op:
+ batch_op.drop_constraint("ab_view_menu_name_uq", type_="unique")
+ op.execute("PRAGMA foreign_keys=ON")
+ elif conn.dialect.name == "mysql":
+ with op.batch_alter_table("ab_user", schema=None) as batch_op:
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=True)
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=True, unique=True)
+ with op.batch_alter_table("ab_register_user", schema=None) as batch_op:
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=True)
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=True, unique=True)
+ elif conn.dialect.name == "mssql":
+ with op.batch_alter_table("ab_register_user") as batch_op:
# Drop the unique constraint on username and email
- constraints = get_mssql_table_constraints(conn, 'ab_register_user')
- for k, _ in constraints.get('UNIQUE').items():
- batch_op.drop_constraint(k, type_='unique')
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=False, unique=True)
- batch_op.create_unique_constraint(None, ['username'])
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=False, unique=True)
- with op.batch_alter_table('ab_user') as batch_op:
+ constraints = get_mssql_table_constraints(conn, "ab_register_user")
+ for k, _ in constraints.get("UNIQUE").items():
+ batch_op.drop_constraint(k, type_="unique")
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=False, unique=True)
+ batch_op.create_unique_constraint(None, ["username"])
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=False, unique=True)
+ with op.batch_alter_table("ab_user") as batch_op:
# Drop the unique constraint on username and email
- constraints = get_mssql_table_constraints(conn, 'ab_user')
- for k, _ in constraints.get('UNIQUE').items():
- batch_op.drop_constraint(k, type_='unique')
- batch_op.alter_column('username', existing_type=sa.String(256), nullable=True)
- batch_op.create_unique_constraint(None, ['username'])
- batch_op.alter_column('email', existing_type=sa.String(256), nullable=True, unique=True)
- batch_op.create_unique_constraint(None, ['email'])
+ constraints = get_mssql_table_constraints(conn, "ab_user")
+ for k, _ in constraints.get("UNIQUE").items():
+ batch_op.drop_constraint(k, type_="unique")
+ batch_op.alter_column("username", existing_type=sa.String(256), nullable=True)
+ batch_op.create_unique_constraint(None, ["username"])
+ batch_op.alter_column("email", existing_type=sa.String(256), nullable=True, unique=True)
+ batch_op.create_unique_constraint(None, ["email"])
diff --git a/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py b/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py
index f48cd90cfa1a3..96c52b03967b3 100644
--- a/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py
+++ b/airflow/migrations/versions/0107_2_3_0_add_map_index_to_log.py
@@ -15,13 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add map_index to Log.
Revision ID: 75d5ed6c2b43
Revises: 909884dea523
Create Date: 2022-03-15 16:35:54.816863
"""
+from __future__ import annotations
+
from alembic import op
from sqlalchemy import Column, Integer
@@ -30,7 +31,7 @@
down_revision = "909884dea523"
branch_labels = None
depends_on = None
-airflow_version = '2.3.0'
+airflow_version = "2.3.0"
def upgrade():
diff --git a/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py b/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py
index 03a9f43c3ad62..7875d1eceb30d 100644
--- a/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py
+++ b/airflow/migrations/versions/0108_2_3_0_default_dag_view_grid.py
@@ -15,41 +15,41 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""Update dag.default_view to grid
+"""Update dag.default_view to grid.
Revision ID: b1b348e02d07
Revises: 75d5ed6c2b43
Create Date: 2022-04-19 17:25:00.872220
"""
+from __future__ import annotations
from alembic import op
from sqlalchemy import String
from sqlalchemy.sql import column, table
# revision identifiers, used by Alembic.
-revision = 'b1b348e02d07'
-down_revision = '75d5ed6c2b43'
+revision = "b1b348e02d07"
+down_revision = "75d5ed6c2b43"
branch_labels = None
-depends_on = '75d5ed6c2b43'
-airflow_version = '2.3.0'
+depends_on = "75d5ed6c2b43"
+airflow_version = "2.3.0"
-dag = table('dag', column('default_view', String))
+dag = table("dag", column("default_view", String))
def upgrade():
op.execute(
dag.update()
- .where(dag.c.default_view == op.inline_literal('tree'))
- .values({'default_view': op.inline_literal('grid')})
+ .where(dag.c.default_view == op.inline_literal("tree"))
+ .values({"default_view": op.inline_literal("grid")})
)
def downgrade():
op.execute(
dag.update()
- .where(dag.c.default_view == op.inline_literal('grid'))
- .values({'default_view': op.inline_literal('tree')})
+ .where(dag.c.default_view == op.inline_literal("grid"))
+ .values({"default_view": op.inline_literal("tree")})
)
diff --git a/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py b/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py
index 2023a3c294fee..bfe958520be11 100644
--- a/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py
+++ b/airflow/migrations/versions/0109_2_3_1_add_index_for_event_in_log.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add index for ``event`` column in ``log`` table.
Revision ID: 1de7bc13c950
@@ -23,22 +22,23 @@
Create Date: 2022-05-10 18:18:43.484829
"""
+from __future__ import annotations
from alembic import op
# revision identifiers, used by Alembic.
-revision = '1de7bc13c950'
-down_revision = 'b1b348e02d07'
+revision = "1de7bc13c950"
+down_revision = "b1b348e02d07"
branch_labels = None
depends_on = None
-airflow_version = '2.3.1'
+airflow_version = "2.3.1"
def upgrade():
"""Apply Add index for ``event`` column in ``log`` table."""
- op.create_index('idx_log_event', 'log', ['event'], unique=False)
+ op.create_index("idx_log_event", "log", ["event"], unique=False)
def downgrade():
"""Unapply Add index for ``event`` column in ``log`` table."""
- op.drop_index('idx_log_event', table_name='log')
+ op.drop_index("idx_log_event", table_name="log")
diff --git a/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py b/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py
index 55d9e9754e532..6ca9b2afb9528 100644
--- a/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py
+++ b/airflow/migrations/versions/0110_2_3_2_add_cascade_to_dag_tag_foreignkey.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Add cascade to dag_tag foreign key
Revision ID: 3c94c427fdf6
@@ -23,62 +22,60 @@
Create Date: 2022-05-03 09:47:41.957710
"""
+from __future__ import annotations
from alembic import op
+from sqlalchemy import inspect
from airflow.migrations.utils import get_mssql_table_constraints
# revision identifiers, used by Alembic.
-revision = '3c94c427fdf6'
-down_revision = '1de7bc13c950'
+revision = "3c94c427fdf6"
+down_revision = "1de7bc13c950"
branch_labels = None
depends_on = None
-airflow_version = '2.3.2'
+airflow_version = "2.3.2"
def upgrade():
"""Apply Add cascade to dag_tag foreignkey"""
conn = op.get_bind()
- if conn.dialect.name == 'sqlite':
- naming_convention = {
- "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
- }
+ if conn.dialect.name in ["sqlite", "mysql"]:
+ inspector = inspect(conn.engine)
+ foreignkey = inspector.get_foreign_keys("dag_tag")
with op.batch_alter_table(
- 'dag_tag', naming_convention=naming_convention, recreate='always'
+ "dag_tag",
) as batch_op:
- batch_op.drop_constraint('fk_dag_tag_dag_id_dag', type_='foreignkey')
+ batch_op.drop_constraint(foreignkey[0]["name"], type_="foreignkey")
batch_op.create_foreign_key(
- "dag_tag_dag_id_fkey", 'dag', ['dag_id'], ['dag_id'], ondelete='CASCADE'
+ "dag_tag_dag_id_fkey", "dag", ["dag_id"], ["dag_id"], ondelete="CASCADE"
)
else:
- with op.batch_alter_table('dag_tag') as batch_op:
- if conn.dialect.name == 'mssql':
- constraints = get_mssql_table_constraints(conn, 'dag_tag')
- Fk, _ = constraints['FOREIGN KEY'].popitem()
- batch_op.drop_constraint(Fk, type_='foreignkey')
- if conn.dialect.name == 'postgresql':
- batch_op.drop_constraint('dag_tag_dag_id_fkey', type_='foreignkey')
- if conn.dialect.name == 'mysql':
- batch_op.drop_constraint('dag_tag_ibfk_1', type_='foreignkey')
-
+ with op.batch_alter_table("dag_tag") as batch_op:
+ if conn.dialect.name == "mssql":
+ constraints = get_mssql_table_constraints(conn, "dag_tag")
+ Fk, _ = constraints["FOREIGN KEY"].popitem()
+ batch_op.drop_constraint(Fk, type_="foreignkey")
+ if conn.dialect.name == "postgresql":
+ batch_op.drop_constraint("dag_tag_dag_id_fkey", type_="foreignkey")
batch_op.create_foreign_key(
- "dag_tag_dag_id_fkey", 'dag', ['dag_id'], ['dag_id'], ondelete='CASCADE'
+ "dag_tag_dag_id_fkey", "dag", ["dag_id"], ["dag_id"], ondelete="CASCADE"
)
def downgrade():
"""Unapply Add cascade to dag_tag foreignkey"""
conn = op.get_bind()
- if conn.dialect.name == 'sqlite':
- with op.batch_alter_table('dag_tag') as batch_op:
- batch_op.drop_constraint('dag_tag_dag_id_fkey', type_='foreignkey')
- batch_op.create_foreign_key("fk_dag_tag_dag_id_dag", 'dag', ['dag_id'], ['dag_id'])
+ if conn.dialect.name == "sqlite":
+ with op.batch_alter_table("dag_tag") as batch_op:
+ batch_op.drop_constraint("dag_tag_dag_id_fkey", type_="foreignkey")
+ batch_op.create_foreign_key("fk_dag_tag_dag_id_dag", "dag", ["dag_id"], ["dag_id"])
else:
- with op.batch_alter_table('dag_tag') as batch_op:
- batch_op.drop_constraint('dag_tag_dag_id_fkey', type_='foreignkey')
+ with op.batch_alter_table("dag_tag") as batch_op:
+ batch_op.drop_constraint("dag_tag_dag_id_fkey", type_="foreignkey")
batch_op.create_foreign_key(
None,
- 'dag',
- ['dag_id'],
- ['dag_id'],
+ "dag",
+ ["dag_id"],
+ ["dag_id"],
)
diff --git a/airflow/migrations/versions/0111_2_3_3_add_indexes_for_cascade_deletes.py b/airflow/migrations/versions/0111_2_3_3_add_indexes_for_cascade_deletes.py
new file mode 100644
index 0000000000000..d2c39177fafcd
--- /dev/null
+++ b/airflow/migrations/versions/0111_2_3_3_add_indexes_for_cascade_deletes.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add indexes for CASCADE deletes on task_instance
+
+Some databases don't add indexes on the FK columns so we have to add them for performance on CASCADE deletes.
+
+Revision ID: f5fcbda3e651
+Revises: 3c94c427fdf6
+Create Date: 2022-06-15 18:04:54.081789
+
+"""
+from __future__ import annotations
+
+from alembic import context, op
+
+# revision identifiers, used by Alembic.
+revision = "f5fcbda3e651"
+down_revision = "3c94c427fdf6"
+branch_labels = None
+depends_on = None
+airflow_version = "2.3.3"
+
+
+def _mysql_tables_where_indexes_already_present(conn):
+ """
+ If user downgraded and is upgrading again, we have to check for existing
+ indexes on mysql because we can't (and don't) drop them as part of the
+ downgrade.
+ """
+ to_check = [
+ ("xcom", "idx_xcom_task_instance"),
+ ("task_reschedule", "idx_task_reschedule_dag_run"),
+ ("task_fail", "idx_task_fail_task_instance"),
+ ]
+ tables = set()
+ for tbl, idx in to_check:
+ if conn.execute(f"show indexes from {tbl} where Key_name = '{idx}'").first():
+ tables.add(tbl)
+ return tables
+
+
+def upgrade():
+ """Apply Add indexes for CASCADE deletes"""
+ conn = op.get_bind()
+ tables_to_skip = set()
+
+ # MySQL requires an index on FK columns, so adding these indexes effectively renamed the existing ones, and they cannot be removed.
+ if conn.dialect.name == "mysql" and not context.is_offline_mode():
+ tables_to_skip.update(_mysql_tables_where_indexes_already_present(conn))
+
+ if "task_fail" not in tables_to_skip:
+ with op.batch_alter_table("task_fail", schema=None) as batch_op:
+ batch_op.create_index("idx_task_fail_task_instance", ["dag_id", "task_id", "run_id", "map_index"])
+
+ if "task_reschedule" not in tables_to_skip:
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.create_index("idx_task_reschedule_dag_run", ["dag_id", "run_id"])
+
+ if "xcom" not in tables_to_skip:
+ with op.batch_alter_table("xcom", schema=None) as batch_op:
+ batch_op.create_index("idx_xcom_task_instance", ["dag_id", "task_id", "run_id", "map_index"])
+
+
+def downgrade():
+ """Unapply Add indexes for CASCADE deletes"""
+ conn = op.get_bind()
+
+ # MySQL requires an index on FK columns, so adding these indexes effectively renamed the existing ones, and they cannot be removed.
+ if conn.dialect.name == "mysql":
+ return
+
+ with op.batch_alter_table("xcom", schema=None) as batch_op:
+ batch_op.drop_index("idx_xcom_task_instance")
+
+ with op.batch_alter_table("task_reschedule", schema=None) as batch_op:
+ batch_op.drop_index("idx_task_reschedule_dag_run")
+
+ with op.batch_alter_table("task_fail", schema=None) as batch_op:
+ batch_op.drop_index("idx_task_fail_task_instance")
diff --git a/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py b/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py
new file mode 100644
index 0000000000000..2bb999b21b350
--- /dev/null
+++ b/airflow/migrations/versions/0112_2_4_0_add_dagwarning_model.py
@@ -0,0 +1,60 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add DagWarning model
+
+Revision ID: 424117c37d18
+Revises: f5fcbda3e651
+Create Date: 2022-04-27 15:57:36.736743
+"""
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.migrations.db_types import TIMESTAMP, StringID
+
+# revision identifiers, used by Alembic.
+
+
+revision = "424117c37d18"
+down_revision = "f5fcbda3e651"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.0"
+
+
+def upgrade():
+ """Apply Add DagWarning model"""
+ op.create_table(
+ "dag_warning",
+ sa.Column("dag_id", StringID(), primary_key=True),
+ sa.Column("warning_type", sa.String(length=50), primary_key=True),
+ sa.Column("message", sa.Text(), nullable=False),
+ sa.Column("timestamp", TIMESTAMP, nullable=False),
+ sa.ForeignKeyConstraint(
+ ("dag_id",),
+ ["dag.dag_id"],
+ name="dcw_dag_id_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+
+def downgrade():
+ """Unapply Add DagWarning model"""
+ op.drop_table("dag_warning")
diff --git a/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py b/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py
new file mode 100644
index 0000000000000..8beebc80a9aa7
--- /dev/null
+++ b/airflow/migrations/versions/0113_2_4_0_compare_types_between_orm_and_db.py
@@ -0,0 +1,262 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""compare types between ORM and DB.
+
+Revision ID: 44b7034f6bdc
+Revises: 424117c37d18
+Create Date: 2022-05-31 09:16:44.558754
+
+"""
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.migrations.db_types import TIMESTAMP
+
+# revision identifiers, used by Alembic.
+revision = "44b7034f6bdc"
+down_revision = "424117c37d18"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.0"
+
+
+def upgrade():
+ """Apply compare types between ORM and DB."""
+ conn = op.get_bind()
+ with op.batch_alter_table("connection", schema=None) as batch_op:
+ batch_op.alter_column(
+ "extra",
+ existing_type=sa.TEXT(),
+ type_=sa.Text(),
+ existing_nullable=True,
+ )
+ with op.batch_alter_table("log_template", schema=None) as batch_op:
+ batch_op.alter_column(
+ "created_at", existing_type=sa.DateTime(), type_=TIMESTAMP(), existing_nullable=False
+ )
+
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ # drop server_default
+ batch_op.alter_column(
+ "dag_hash",
+ existing_type=sa.String(32),
+ server_default=None,
+ type_=sa.String(32),
+ existing_nullable=False,
+ )
+ with op.batch_alter_table("trigger", schema=None) as batch_op:
+ batch_op.alter_column(
+ "created_date", existing_type=sa.DateTime(), type_=TIMESTAMP(), existing_nullable=False
+ )
+
+ if conn.dialect.name != "sqlite":
+ return
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ batch_op.alter_column("fileloc_hash", existing_type=sa.Integer, type_=sa.BigInteger())
+ # Some SQLite date columns are not of type db_types.TIMESTAMP. Convert these to TIMESTAMP.
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.alter_column(
+ "last_pickled", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "last_expired", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("dag_pickle", schema=None) as batch_op:
+ batch_op.alter_column(
+ "created_dttm", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.alter_column(
+ "execution_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=False
+ )
+ batch_op.alter_column(
+ "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("import_error", schema=None) as batch_op:
+ batch_op.alter_column(
+ "timestamp", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("job", schema=None) as batch_op:
+ batch_op.alter_column(
+ "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "latest_heartbeat", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("log", schema=None) as batch_op:
+ batch_op.alter_column("dttm", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True)
+ batch_op.alter_column(
+ "execution_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ batch_op.alter_column(
+ "last_updated", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=False
+ )
+
+ with op.batch_alter_table("sla_miss", schema=None) as batch_op:
+ batch_op.alter_column(
+ "execution_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=False
+ )
+ batch_op.alter_column(
+ "timestamp", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("task_fail", schema=None) as batch_op:
+ batch_op.alter_column(
+ "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.alter_column(
+ "start_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "end_date", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "queued_dttm", existing_type=sa.DATETIME(), type_=TIMESTAMP(), existing_nullable=True
+ )
+
+
+def downgrade():
+ """Unapply compare types between ORM and DB."""
+ with op.batch_alter_table("connection", schema=None) as batch_op:
+ batch_op.alter_column(
+ "extra",
+ existing_type=sa.Text(),
+ type_=sa.TEXT(),
+ existing_nullable=True,
+ )
+ with op.batch_alter_table("log_template", schema=None) as batch_op:
+ batch_op.alter_column(
+ "created_at", existing_type=TIMESTAMP(), type_=sa.DateTime(), existing_nullable=False
+ )
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ # add server_default
+ batch_op.alter_column(
+ "dag_hash",
+ existing_type=sa.String(32),
+ server_default="Hash not calculated yet",
+ type_=sa.String(32),
+ existing_nullable=False,
+ )
+ with op.batch_alter_table("trigger", schema=None) as batch_op:
+ batch_op.alter_column(
+ "created_date", existing_type=TIMESTAMP(), type_=sa.DateTime(), existing_nullable=False
+ )
+ conn = op.get_bind()
+
+ if conn.dialect.name != "sqlite":
+ return
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ batch_op.alter_column("fileloc_hash", existing_type=sa.BigInteger, type_=sa.Integer())
+ # Change these columns back to sa.DATETIME()
+ with op.batch_alter_table("task_instance", schema=None) as batch_op:
+ batch_op.alter_column(
+ "queued_dttm", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("task_fail", schema=None) as batch_op:
+ batch_op.alter_column(
+ "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("sla_miss", schema=None) as batch_op:
+ batch_op.alter_column(
+ "timestamp", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "execution_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=False
+ )
+
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ batch_op.alter_column(
+ "last_updated", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=False
+ )
+
+ with op.batch_alter_table("log", schema=None) as batch_op:
+ batch_op.alter_column(
+ "execution_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column("dttm", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True)
+
+ with op.batch_alter_table("job", schema=None) as batch_op:
+ batch_op.alter_column(
+ "latest_heartbeat", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("import_error", schema=None) as batch_op:
+ batch_op.alter_column(
+ "timestamp", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("dag_run", schema=None) as batch_op:
+ batch_op.alter_column(
+ "end_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "start_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "execution_date", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=False
+ )
+
+ with op.batch_alter_table("dag_pickle", schema=None) as batch_op:
+ batch_op.alter_column(
+ "created_dttm", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.alter_column(
+ "last_expired", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
+ batch_op.alter_column(
+ "last_pickled", existing_type=TIMESTAMP(), type_=sa.DATETIME(), existing_nullable=True
+ )
diff --git a/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
new file mode 100644
index 0000000000000..b6671a8c075a2
--- /dev/null
+++ b/airflow/migrations/versions/0114_2_4_0_add_dataset_model.py
@@ -0,0 +1,190 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add Dataset model
+
+Revision ID: 0038cd0c28b4
+Revises: 44b7034f6bdc
+Create Date: 2022-06-22 14:37:20.880672
+
+"""
+from __future__ import annotations
+
+import sqlalchemy as sa
+import sqlalchemy_jsonfield
+from alembic import op
+from sqlalchemy import Integer, String, func
+
+from airflow.migrations.db_types import TIMESTAMP, StringID
+from airflow.settings import json
+
+revision = "0038cd0c28b4"
+down_revision = "44b7034f6bdc"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.0"
+
+
+def _create_dataset_table():
+ op.create_table(
+ "dataset",
+ sa.Column("id", Integer, primary_key=True, autoincrement=True),
+ sa.Column(
+ "uri",
+ String(length=3000).with_variant(
+ String(
+ length=3000,
+ # latin1 allows for more indexed length in mysql
+ # and this field should only be ascii chars
+ collation="latin1_general_cs",
+ ),
+ "mysql",
+ ),
+ nullable=False,
+ ),
+ sa.Column("extra", sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}),
+ sa.Column("created_at", TIMESTAMP, nullable=False),
+ sa.Column("updated_at", TIMESTAMP, nullable=False),
+ sqlite_autoincrement=True, # ensures PK values not reused
+ )
+ op.create_index("idx_uri_unique", "dataset", ["uri"], unique=True)
+
+
+def _create_dag_schedule_dataset_reference_table():
+ op.create_table(
+ "dag_schedule_dataset_reference",
+ sa.Column("dataset_id", Integer, primary_key=True, nullable=False),
+ sa.Column("dag_id", StringID(), primary_key=True, nullable=False),
+ sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.ForeignKeyConstraint(
+ ("dataset_id",),
+ ["dataset.id"],
+ name="dsdr_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ sa.ForeignKeyConstraint(
+ columns=("dag_id",),
+ refcolumns=["dag.dag_id"],
+ name="dsdr_dag_id_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+
+def _create_task_outlet_dataset_reference_table():
+ op.create_table(
+ "task_outlet_dataset_reference",
+ sa.Column("dataset_id", Integer, primary_key=True, nullable=False),
+ sa.Column("dag_id", StringID(), primary_key=True, nullable=False),
+ sa.Column("task_id", StringID(), primary_key=True, nullable=False),
+ sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.ForeignKeyConstraint(
+ ("dataset_id",),
+ ["dataset.id"],
+ name="todr_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ sa.ForeignKeyConstraint(
+ columns=("dag_id",),
+ refcolumns=["dag.dag_id"],
+ name="todr_dag_id_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+
+def _create_dataset_dag_run_queue_table():
+ op.create_table(
+ "dataset_dag_run_queue",
+ sa.Column("dataset_id", Integer, primary_key=True, nullable=False),
+ sa.Column("target_dag_id", StringID(), primary_key=True, nullable=False),
+ sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.ForeignKeyConstraint(
+ ("dataset_id",),
+ ["dataset.id"],
+ name="ddrq_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ sa.ForeignKeyConstraint(
+ ("target_dag_id",),
+ ["dag.dag_id"],
+ name="ddrq_dag_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+
+def _create_dataset_event_table():
+ op.create_table(
+ "dataset_event",
+ sa.Column("id", Integer, primary_key=True, autoincrement=True),
+ sa.Column("dataset_id", Integer, nullable=False),
+ sa.Column("extra", sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={}),
+ sa.Column("source_task_id", String(250), nullable=True),
+ sa.Column("source_dag_id", String(250), nullable=True),
+ sa.Column("source_run_id", String(250), nullable=True),
+ sa.Column("source_map_index", sa.Integer(), nullable=True, server_default="-1"),
+ sa.Column("timestamp", TIMESTAMP, nullable=False),
+ sqlite_autoincrement=True, # ensures PK values not reused
+ )
+ op.create_index("idx_dataset_id_timestamp", "dataset_event", ["dataset_id", "timestamp"])
+
+
+def _create_dataset_event_dag_run_table():
+ op.create_table(
+ "dagrun_dataset_event",
+ sa.Column("dag_run_id", sa.Integer(), nullable=False),
+ sa.Column("event_id", sa.Integer(), nullable=False),
+ sa.ForeignKeyConstraint(
+ ["dag_run_id"],
+ ["dag_run.id"],
+ name=op.f("dagrun_dataset_events_dag_run_id_fkey"),
+ ondelete="CASCADE",
+ ),
+ sa.ForeignKeyConstraint(
+ ["event_id"],
+ ["dataset_event.id"],
+ name=op.f("dagrun_dataset_events_event_id_fkey"),
+ ondelete="CASCADE",
+ ),
+ sa.PrimaryKeyConstraint("dag_run_id", "event_id", name=op.f("dagrun_dataset_events_pkey")),
+ )
+ with op.batch_alter_table("dagrun_dataset_event") as batch_op:
+ batch_op.create_index("idx_dagrun_dataset_events_dag_run_id", ["dag_run_id"], unique=False)
+ batch_op.create_index("idx_dagrun_dataset_events_event_id", ["event_id"], unique=False)
+
+
+def upgrade():
+ """Apply Add Dataset model"""
+ _create_dataset_table()
+ _create_dag_schedule_dataset_reference_table()
+ _create_task_outlet_dataset_reference_table()
+ _create_dataset_dag_run_queue_table()
+ _create_dataset_event_table()
+ _create_dataset_event_dag_run_table()
+
+
+def downgrade():
+ """Unapply Add Dataset model"""
+ op.drop_table("dag_schedule_dataset_reference")
+ op.drop_table("task_outlet_dataset_reference")
+ op.drop_table("dataset_dag_run_queue")
+ op.drop_table("dagrun_dataset_event")
+ op.drop_table("dataset_event")
+ op.drop_table("dataset")
diff --git a/airflow/migrations/versions/0115_2_4_0_remove_smart_sensors.py b/airflow/migrations/versions/0115_2_4_0_remove_smart_sensors.py
new file mode 100644
index 0000000000000..2f1c9994345fe
--- /dev/null
+++ b/airflow/migrations/versions/0115_2_4_0_remove_smart_sensors.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Remove smart sensors
+
+Revision ID: f4ff391becb5
+Revises: 0038cd0c28b4
+Create Date: 2022-08-03 11:33:44.777945
+
+"""
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy import func
+from sqlalchemy.sql import column, table
+
+from airflow.migrations.db_types import TIMESTAMP, StringID
+
+# revision identifiers, used by Alembic.
+revision = "f4ff391becb5"
+down_revision = "0038cd0c28b4"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.0"
+
+
+def upgrade():
+ """Apply Remove smart sensors"""
+ op.drop_table("sensor_instance")
+
+ # Minimal model definition for migrations
+ task_instance = table("task_instance", column("state", sa.String))
+ op.execute(task_instance.update().where(task_instance.c.state == "sensing").values({"state": "failed"}))
+
+
+def downgrade():
+ """Unapply Remove smart sensors"""
+ op.create_table(
+ "sensor_instance",
+ sa.Column("id", sa.Integer(), nullable=False),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("execution_date", TIMESTAMP, nullable=False),
+ sa.Column("state", sa.String(length=20), nullable=True),
+ sa.Column("try_number", sa.Integer(), nullable=True),
+ sa.Column("start_date", TIMESTAMP, nullable=True),
+ sa.Column("operator", sa.String(length=1000), nullable=False),
+ sa.Column("op_classpath", sa.String(length=1000), nullable=False),
+ sa.Column("hashcode", sa.BigInteger(), nullable=False),
+ sa.Column("shardcode", sa.Integer(), nullable=False),
+ sa.Column("poke_context", sa.Text(), nullable=False),
+ sa.Column("execution_context", sa.Text(), nullable=True),
+ sa.Column("created_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.Column("updated_at", TIMESTAMP, default=func.now, nullable=False),
+ sa.PrimaryKeyConstraint("id"),
+ )
+ op.create_index("ti_primary_key", "sensor_instance", ["dag_id", "task_id", "execution_date"], unique=True)
+ op.create_index("si_hashcode", "sensor_instance", ["hashcode"], unique=False)
+ op.create_index("si_shardcode", "sensor_instance", ["shardcode"], unique=False)
+ op.create_index("si_state_shard", "sensor_instance", ["state", "shardcode"], unique=False)
+ op.create_index("si_updated_at", "sensor_instance", ["updated_at"], unique=False)
diff --git a/airflow/migrations/versions/0116_2_4_0_add_dag_owner_attributes_table.py b/airflow/migrations/versions/0116_2_4_0_add_dag_owner_attributes_table.py
new file mode 100644
index 0000000000000..5b4d8c2e6baa2
--- /dev/null
+++ b/airflow/migrations/versions/0116_2_4_0_add_dag_owner_attributes_table.py
@@ -0,0 +1,54 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""add dag_owner_attributes table
+
+Revision ID: 1486deb605b4
+Revises: f4ff391becb5
+Create Date: 2022-08-04 16:59:45.406589
+
+"""
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.migrations.db_types import StringID
+
+# revision identifiers, used by Alembic.
+revision = "1486deb605b4"
+down_revision = "f4ff391becb5"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.0"
+
+
+def upgrade():
+ """Apply Add ``DagOwnerAttributes`` table"""
+ op.create_table(
+ "dag_owner_attributes",
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("owner", sa.String(length=500), nullable=False),
+ sa.Column("link", sa.String(length=500), nullable=False),
+ sa.ForeignKeyConstraint(["dag_id"], ["dag.dag_id"], ondelete="CASCADE"),
+ sa.PrimaryKeyConstraint("dag_id", "owner"),
+ )
+
+
+def downgrade():
+ """Unapply Add ``DagOwnerAttributes`` table"""
+ op.drop_table("dag_owner_attributes")
diff --git a/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py
new file mode 100644
index 0000000000000..6846d4ebb8943
--- /dev/null
+++ b/airflow/migrations/versions/0117_2_4_0_add_processor_subdir_to_dagmodel_and_.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Add processor_subdir column to DagModel, SerializedDagModel and CallbackRequest tables.
+
+Revision ID: ecb43d2a1842
+Revises: 1486deb605b4
+Create Date: 2022-08-26 11:30:11.249580
+
+"""
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "ecb43d2a1842"
+down_revision = "1486deb605b4"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.0"
+
+
+def upgrade():
+ """Apply add processor_subdir to DagModel and SerializedDagModel"""
+ conn = op.get_bind()
+
+ with op.batch_alter_table("dag") as batch_op:
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column("processor_subdir", sa.Text(length=2000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column("processor_subdir", sa.String(length=2000), nullable=True))
+
+ with op.batch_alter_table("serialized_dag") as batch_op:
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column("processor_subdir", sa.Text(length=2000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column("processor_subdir", sa.String(length=2000), nullable=True))
+
+ with op.batch_alter_table("callback_request") as batch_op:
+ batch_op.drop_column("dag_directory")
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column("processor_subdir", sa.Text(length=2000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column("processor_subdir", sa.String(length=2000), nullable=True))
+
+
+def downgrade():
+ """Unapply Add processor_subdir to DagModel and SerializedDagModel"""
+ conn = op.get_bind()
+ with op.batch_alter_table("dag", schema=None) as batch_op:
+ batch_op.drop_column("processor_subdir")
+
+ with op.batch_alter_table("serialized_dag", schema=None) as batch_op:
+ batch_op.drop_column("processor_subdir")
+
+ with op.batch_alter_table("callback_request") as batch_op:
+ batch_op.drop_column("processor_subdir")
+ if conn.dialect.name == "mysql":
+ batch_op.add_column(sa.Column("dag_directory", sa.Text(length=1000), nullable=True))
+ else:
+ batch_op.add_column(sa.Column("dag_directory", sa.String(length=1000), nullable=True))
diff --git a/airflow/migrations/versions/0118_2_4_2_add_missing_autoinc_fab.py b/airflow/migrations/versions/0118_2_4_2_add_missing_autoinc_fab.py
new file mode 100644
index 0000000000000..6c4f010e2f353
--- /dev/null
+++ b/airflow/migrations/versions/0118_2_4_2_add_missing_autoinc_fab.py
@@ -0,0 +1,78 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add missing auto-increment to columns on FAB tables
+
+Revision ID: b0d31815b5a6
+Revises: ecb43d2a1842
+Create Date: 2022-10-05 13:16:45.638490
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "b0d31815b5a6"
+down_revision = "ecb43d2a1842"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.2"
+
+
+def upgrade():
+ """Apply migration.
+
+ If these columns are already of the right type (i.e. created by our
+ migration in 1.10.13 rather than FAB itself in an earlier version), this
+ migration will issue an alter statement to change them to what they already
+ are -- i.e. it's a no-op.
+
+ These tables are small (100 to low 1k rows at most), so it's not too costly
+ to change them.
+ """
+ conn = op.get_bind()
+ if conn.dialect.name in ["mssql", "sqlite"]:
+ # 1.10.12 didn't support SQL Server, so it couldn't have gotten this wrong --> nothing to correct
+ # SQLite autoinc was "implicit" for an INTEGER NOT NULL PRIMARY KEY
+ return
+
+ for table in (
+ "ab_permission",
+ "ab_view_menu",
+ "ab_role",
+ "ab_permission_view",
+ "ab_permission_view_role",
+ "ab_user",
+ "ab_user_role",
+ "ab_register_user",
+ ):
+ with op.batch_alter_table(table) as batch:
+ kwargs = {}
+ if conn.dialect.name == "postgresql":
+ kwargs["server_default"] = sa.Sequence(f"{table}_id_seq").next_value()
+ else:
+ kwargs["autoincrement"] = True
+ batch.alter_column("id", existing_type=sa.Integer(), existing_nullable=False, **kwargs)
+
+
+def downgrade():
+ """Unapply add_missing_autoinc_fab"""
+ # No downgrade needed, these _should_ have applied from 1.10.13 but didn't due to a previous bug!
diff --git a/airflow/migrations/versions/0119_2_4_3_add_case_insensitive_unique_constraint_for_username.py b/airflow/migrations/versions/0119_2_4_3_add_case_insensitive_unique_constraint_for_username.py
new file mode 100644
index 0000000000000..69b52f0864339
--- /dev/null
+++ b/airflow/migrations/versions/0119_2_4_3_add_case_insensitive_unique_constraint_for_username.py
@@ -0,0 +1,89 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add case-insensitive unique constraint for username
+
+Revision ID: e07f49787c9d
+Revises: b0d31815b5a6
+Create Date: 2022-10-25 17:29:46.432326
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "e07f49787c9d"
+down_revision = "b0d31815b5a6"
+branch_labels = None
+depends_on = None
+airflow_version = "2.4.3"
+
+
+def upgrade():
+ """Apply Add case-insensitive unique constraint"""
+ conn = op.get_bind()
+ if conn.dialect.name == "postgresql":
+ op.create_index("idx_ab_user_username", "ab_user", [sa.text("LOWER(username)")], unique=True)
+ op.create_index(
+ "idx_ab_register_user_username", "ab_register_user", [sa.text("LOWER(username)")], unique=True
+ )
+ elif conn.dialect.name == "sqlite":
+ with op.batch_alter_table("ab_user") as batch_op:
+ batch_op.alter_column(
+ "username",
+ existing_type=sa.String(64),
+ _type=sa.String(64, collation="NOCASE"),
+ unique=True,
+ nullable=False,
+ )
+ with op.batch_alter_table("ab_register_user") as batch_op:
+ batch_op.alter_column(
+ "username",
+ existing_type=sa.String(64),
+ _type=sa.String(64, collation="NOCASE"),
+ unique=True,
+ nullable=False,
+ )
+
+
+def downgrade():
+ """Unapply Add case-insensitive unique constraint"""
+ conn = op.get_bind()
+ if conn.dialect.name == "postgresql":
+ op.drop_index("idx_ab_user_username", table_name="ab_user")
+ op.drop_index("idx_ab_register_user_username", table_name="ab_register_user")
+ elif conn.dialect.name == "sqlite":
+ with op.batch_alter_table("ab_user") as batch_op:
+ batch_op.alter_column(
+ "username",
+ existing_type=sa.String(64, collation="NOCASE"),
+ _type=sa.String(64),
+ unique=True,
+ nullable=False,
+ )
+ with op.batch_alter_table("ab_register_user") as batch_op:
+ batch_op.alter_column(
+ "username",
+ existing_type=sa.String(64, collation="NOCASE"),
+ _type=sa.String(64),
+ unique=True,
+ nullable=False,
+ )
diff --git a/airflow/migrations/versions/0120_2_5_0_add_updated_at_to_dagrun_and_ti.py b/airflow/migrations/versions/0120_2_5_0_add_updated_at_to_dagrun_and_ti.py
new file mode 100644
index 0000000000000..f51f89001e82b
--- /dev/null
+++ b/airflow/migrations/versions/0120_2_5_0_add_updated_at_to_dagrun_and_ti.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add updated_at column to DagRun and TaskInstance
+
+Revision ID: ee8d93fcc81e
+Revises: e07f49787c9d
+Create Date: 2022-09-08 19:08:37.623121
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.migrations.db_types import TIMESTAMP
+
+# revision identifiers, used by Alembic.
+revision = "ee8d93fcc81e"
+down_revision = "e07f49787c9d"
+branch_labels = None
+depends_on = None
+airflow_version = "2.5.0"
+
+
+def upgrade():
+ """Apply add updated_at column to DagRun and TaskInstance"""
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.add_column(sa.Column("updated_at", TIMESTAMP, default=sa.func.now))
+
+ with op.batch_alter_table("dag_run") as batch_op:
+ batch_op.add_column(sa.Column("updated_at", TIMESTAMP, default=sa.func.now))
+
+
+def downgrade():
+ """Unapply add updated_at column to DagRun and TaskInstance"""
+ with op.batch_alter_table("task_instance") as batch_op:
+ batch_op.drop_column("updated_at")
+
+ with op.batch_alter_table("dag_run") as batch_op:
+ batch_op.drop_column("updated_at")
diff --git a/airflow/migrations/versions/0121_2_5_0_add_dagrunnote_and_taskinstancenote.py b/airflow/migrations/versions/0121_2_5_0_add_dagrunnote_and_taskinstancenote.py
new file mode 100644
index 0000000000000..d13fba38aa70c
--- /dev/null
+++ b/airflow/migrations/versions/0121_2_5_0_add_dagrunnote_and_taskinstancenote.py
@@ -0,0 +1,94 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add DagRunNote and TaskInstanceNote
+
+Revision ID: 1986afd32c1b
+Revises: ee8d93fcc81e
+Create Date: 2022-11-22 21:49:05.843439
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+from airflow.migrations.db_types import StringID
+from airflow.utils.sqlalchemy import UtcDateTime
+
+# revision identifiers, used by Alembic.
+revision = "1986afd32c1b"
+down_revision = "ee8d93fcc81e"
+branch_labels = None
+depends_on = None
+airflow_version = "2.5.0"
+
+
+def upgrade():
+ """Apply Add DagRunNote and TaskInstanceNote"""
+ op.create_table(
+ "dag_run_note",
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("dag_run_id", sa.Integer(), nullable=False),
+ sa.Column(
+ "content", sa.String(length=1000).with_variant(sa.Text(length=1000), "mysql"), nullable=True
+ ),
+ sa.Column("created_at", UtcDateTime(timezone=True), nullable=False),
+ sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False),
+ sa.ForeignKeyConstraint(
+ ("dag_run_id",), ["dag_run.id"], name="dag_run_note_dr_fkey", ondelete="CASCADE"
+ ),
+ sa.ForeignKeyConstraint(("user_id",), ["ab_user.id"], name="dag_run_note_user_fkey"),
+ sa.PrimaryKeyConstraint("dag_run_id", name=op.f("dag_run_note_pkey")),
+ )
+
+ op.create_table(
+ "task_instance_note",
+ sa.Column("user_id", sa.Integer(), nullable=True),
+ sa.Column("task_id", StringID(), nullable=False),
+ sa.Column("dag_id", StringID(), nullable=False),
+ sa.Column("run_id", StringID(), nullable=False),
+ sa.Column("map_index", sa.Integer(), nullable=False),
+ sa.Column(
+ "content", sa.String(length=1000).with_variant(sa.Text(length=1000), "mysql"), nullable=True
+ ),
+ sa.Column("created_at", UtcDateTime(timezone=True), nullable=False),
+ sa.Column("updated_at", UtcDateTime(timezone=True), nullable=False),
+ sa.PrimaryKeyConstraint(
+ "task_id", "dag_id", "run_id", "map_index", name=op.f("task_instance_note_pkey")
+ ),
+ sa.ForeignKeyConstraint(
+ ("dag_id", "task_id", "run_id", "map_index"),
+ [
+ "task_instance.dag_id",
+ "task_instance.task_id",
+ "task_instance.run_id",
+ "task_instance.map_index",
+ ],
+ name="task_instance_note_ti_fkey",
+ ondelete="CASCADE",
+ ),
+ sa.ForeignKeyConstraint(("user_id",), ["ab_user.id"], name="task_instance_note_user_fkey"),
+ )
+
+
+def downgrade():
+ """Unapply Add DagRunNote and TaskInstanceNote"""
+ op.drop_table("task_instance_note")
+ op.drop_table("dag_run_note")
diff --git a/airflow/migrations/versions/0122_2_5_0_add_is_orphaned_to_datasetmodel.py b/airflow/migrations/versions/0122_2_5_0_add_is_orphaned_to_datasetmodel.py
new file mode 100644
index 0000000000000..f4355402c84d6
--- /dev/null
+++ b/airflow/migrations/versions/0122_2_5_0_add_is_orphaned_to_datasetmodel.py
@@ -0,0 +1,57 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Add is_orphaned to DatasetModel
+
+Revision ID: 290244fb8b83
+Revises: 1986afd32c1b
+Create Date: 2022-11-22 00:12:53.432961
+
+"""
+
+from __future__ import annotations
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "290244fb8b83"
+down_revision = "1986afd32c1b"
+branch_labels = None
+depends_on = None
+airflow_version = "2.5.0"
+
+
+def upgrade():
+ """Add is_orphaned to DatasetModel"""
+ with op.batch_alter_table("dataset") as batch_op:
+ batch_op.add_column(
+ sa.Column(
+ "is_orphaned",
+ sa.Boolean,
+ default=False,
+ nullable=False,
+ server_default="0",
+ )
+ )
+
+
+def downgrade():
+ """Remove is_orphaned from DatasetModel"""
+ with op.batch_alter_table("dataset") as batch_op:
+ batch_op.drop_column("is_orphaned", mssql_drop_default=True)
diff --git a/airflow/models/__init__.py b/airflow/models/__init__.py
index c36dfb1bca0ba..c750757cb048d 100644
--- a/airflow/models/__init__.py
+++ b/airflow/models/__init__.py
@@ -16,33 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""Airflow models"""
-from typing import Union
-
-from airflow.models.base import ID_LEN, Base
-from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
-from airflow.models.connection import Connection
-from airflow.models.dag import DAG, DagModel, DagTag
-from airflow.models.dagbag import DagBag
-from airflow.models.dagpickle import DagPickle
-from airflow.models.dagrun import DagRun
-from airflow.models.db_callback_request import DbCallbackRequest
-from airflow.models.errors import ImportError
-from airflow.models.log import Log
-from airflow.models.mappedoperator import MappedOperator
-from airflow.models.operator import Operator
-from airflow.models.param import Param
-from airflow.models.pool import Pool
-from airflow.models.renderedtifields import RenderedTaskInstanceFields
-from airflow.models.sensorinstance import SensorInstance
-from airflow.models.skipmixin import SkipMixin
-from airflow.models.slamiss import SlaMiss
-from airflow.models.taskfail import TaskFail
-from airflow.models.taskinstance import TaskInstance, clear_task_instances
-from airflow.models.taskreschedule import TaskReschedule
-from airflow.models.trigger import Trigger
-from airflow.models.variable import Variable
-from airflow.models.xcom import XCOM_RETURN_KEY, XCom
+from __future__ import annotations
+# Do not add new models to this -- this is for compat only
__all__ = [
"DAG",
"ID_LEN",
@@ -52,6 +28,7 @@
"BaseOperatorLink",
"Connection",
"DagBag",
+ "DagWarning",
"DagModel",
"DagPickle",
"DagRun",
@@ -64,7 +41,6 @@
"Param",
"Pool",
"RenderedTaskInstanceFields",
- "SensorInstance",
"SkipMixin",
"SlaMiss",
"TaskFail",
@@ -75,3 +51,96 @@
"XCom",
"clear_task_instances",
]
+
+
+from typing import TYPE_CHECKING
+
+
+def import_all_models():
+ for name in __lazy_imports:
+ __getattr__(name)
+
+ import airflow.jobs.backfill_job
+ import airflow.jobs.base_job
+ import airflow.jobs.local_task_job
+ import airflow.jobs.scheduler_job
+ import airflow.jobs.triggerer_job
+ import airflow.models.dagwarning
+ import airflow.models.dataset
+ import airflow.models.serialized_dag
+ import airflow.models.tasklog
+ import airflow.www.fab_security.sqla.models
+
+
+def __getattr__(name):
+ # PEP-562: Lazy loaded attributes on python modules
+ path = __lazy_imports.get(name)
+ if not path:
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
+ from airflow.utils.module_loading import import_string
+
+ val = import_string(f"{path}.{name}")
+ # Store for next time
+ globals()[name] = val
+ return val
+
+
+__lazy_imports = {
+ "DAG": "airflow.models.dag",
+ "ID_LEN": "airflow.models.base",
+ "XCOM_RETURN_KEY": "airflow.models.xcom",
+ "Base": "airflow.models.base",
+ "BaseOperator": "airflow.models.baseoperator",
+ "BaseOperatorLink": "airflow.models.baseoperator",
+ "Connection": "airflow.models.connection",
+ "DagBag": "airflow.models.dagbag",
+ "DagModel": "airflow.models.dag",
+ "DagPickle": "airflow.models.dagpickle",
+ "DagRun": "airflow.models.dagrun",
+ "DagTag": "airflow.models.dag",
+ "DbCallbackRequest": "airflow.models.db_callback_request",
+ "ImportError": "airflow.models.errors",
+ "Log": "airflow.models.log",
+ "MappedOperator": "airflow.models.mappedoperator",
+ "Operator": "airflow.models.operator",
+ "Param": "airflow.models.param",
+ "Pool": "airflow.models.pool",
+ "RenderedTaskInstanceFields": "airflow.models.renderedtifields",
+ "SkipMixin": "airflow.models.skipmixin",
+ "SlaMiss": "airflow.models.slamiss",
+ "TaskFail": "airflow.models.taskfail",
+ "TaskInstance": "airflow.models.taskinstance",
+ "TaskReschedule": "airflow.models.taskreschedule",
+ "Trigger": "airflow.models.trigger",
+ "Variable": "airflow.models.variable",
+ "XCom": "airflow.models.xcom",
+ "clear_task_instances": "airflow.models.taskinstance",
+}
+
+if TYPE_CHECKING:
+ # I was unable to get mypy to respect an airflow/models/__init__.pyi, so
+ # we have to resort to this hacky method
+ from airflow.models.base import ID_LEN, Base
+ from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
+ from airflow.models.connection import Connection
+ from airflow.models.dag import DAG, DagModel, DagTag
+ from airflow.models.dagbag import DagBag
+ from airflow.models.dagpickle import DagPickle
+ from airflow.models.dagrun import DagRun
+ from airflow.models.db_callback_request import DbCallbackRequest
+ from airflow.models.errors import ImportError
+ from airflow.models.log import Log
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.models.operator import Operator
+ from airflow.models.param import Param
+ from airflow.models.pool import Pool
+ from airflow.models.renderedtifields import RenderedTaskInstanceFields
+ from airflow.models.skipmixin import SkipMixin
+ from airflow.models.slamiss import SlaMiss
+ from airflow.models.taskfail import TaskFail
+ from airflow.models.taskinstance import TaskInstance, clear_task_instances
+ from airflow.models.taskreschedule import TaskReschedule
+ from airflow.models.trigger import Trigger
+ from airflow.models.variable import Variable
+ from airflow.models.xcom import XCOM_RETURN_KEY, XCom
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index cb566ea6f9557..d693f8bfc95cb 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -15,34 +15,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import datetime
import inspect
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- ClassVar,
- Collection,
- Dict,
- FrozenSet,
- Iterable,
- List,
- Optional,
- Sequence,
- Set,
- Type,
- Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, ClassVar, Collection, Iterable, Iterator, Sequence
-from airflow.compat.functools import cached_property
+from airflow.compat.functools import cache, cached_property
from airflow.configuration import conf
from airflow.exceptions import AirflowException
+from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskmixin import DAGNode
from airflow.utils.context import Context
from airflow.utils.helpers import render_template_as_native, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import skip_locked, with_row_locks
+from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
@@ -54,6 +45,7 @@
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG
+ from airflow.models.mappedoperator import MappedOperator
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
@@ -72,11 +64,15 @@
conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
-DEFAULT_TASK_EXECUTION_TIMEOUT: Optional[datetime.timedelta] = conf.gettimedelta(
+DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = conf.gettimedelta(
"core", "default_task_execution_timeout"
)
+class NotMapped(Exception):
+ """Raise if a task is neither mapped nor has any parent mapped groups."""
+
+
class AbstractOperator(LoggingMixin, DAGNode):
"""Common implementation for operators, including unmapped and mapped.
@@ -90,13 +86,13 @@ class AbstractOperator(LoggingMixin, DAGNode):
:meta private:
"""
- operator_class: Union[Type["BaseOperator"], Dict[str, Any]]
+ operator_class: type[BaseOperator] | dict[str, Any]
weight_rule: str
priority_weight: int
# Defines the operator level extra links.
- operator_extra_links: Collection["BaseOperatorLink"]
+ operator_extra_links: Collection[BaseOperatorLink]
# For derived classes to define which fields will get jinjaified.
template_fields: Collection[str]
# Defines which files extensions to look for in the templated fields.
@@ -105,32 +101,39 @@ class AbstractOperator(LoggingMixin, DAGNode):
owner: str
task_id: str
- HIDE_ATTRS_FROM_UI: ClassVar[FrozenSet[str]] = frozenset(
+ outlets: list
+ inlets: list
+
+ HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = frozenset(
(
- 'log',
- 'dag', # We show dag_id, don't need to show this too
- 'node_id', # Duplicates task_id
- 'task_group', # Doesn't have a useful repr, no point showing in UI
- 'inherits_from_empty_operator', # impl detail
+ "log",
+ "dag", # We show dag_id, don't need to show this too
+ "node_id", # Duplicates task_id
+ "task_group", # Doesn't have a useful repr, no point showing in UI
+ "inherits_from_empty_operator", # impl detail
# For compatibility with TG, for operators these are just the current task, no point showing
- 'roots',
- 'leaves',
+ "roots",
+ "leaves",
# These lists are already shown via *_task_ids
- 'upstream_list',
- 'downstream_list',
+ "upstream_list",
+ "downstream_list",
# Not useful, implementation detail, already shown elsewhere
- 'global_operator_extra_link_dict',
- 'operator_extra_link_dict',
+ "global_operator_extra_link_dict",
+ "operator_extra_link_dict",
)
)
- def get_dag(self) -> "Optional[DAG]":
+ def get_dag(self) -> DAG | None:
raise NotImplementedError()
@property
def task_type(self) -> str:
raise NotImplementedError()
+ @property
+ def operator_name(self) -> str:
+ raise NotImplementedError()
+
@property
def inherits_from_empty_operator(self) -> bool:
raise NotImplementedError()
@@ -147,7 +150,7 @@ def dag_id(self) -> str:
def node_id(self) -> str:
return self.task_id
- def get_template_env(self) -> "jinja2.Environment":
+ def get_template_env(self) -> jinja2.Environment:
"""Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG."""
# This is imported locally since Jinja2 is heavy and we don't need it
# for most of the functionalities. It is imported by get_template_env()
@@ -189,7 +192,7 @@ def resolve_template_files(self) -> None:
self.log.exception("Failed to get source %s", item)
self.prepare_template()
- def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
+ def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
"""Get direct relative IDs to the current task, upstream or downstream."""
if upstream:
return self.upstream_task_ids
@@ -198,32 +201,116 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
def get_flat_relative_ids(
self,
upstream: bool = False,
- found_descendants: Optional[Set[str]] = None,
- ) -> Set[str]:
+ found_descendants: set[str] | None = None,
+ ) -> set[str]:
"""Get a flat set of relative IDs, upstream or downstream."""
dag = self.get_dag()
if not dag:
return set()
- if not found_descendants:
+ if found_descendants is None:
found_descendants = set()
- relative_ids = self.get_direct_relative_ids(upstream)
- for relative_id in relative_ids:
- if relative_id not in found_descendants:
- found_descendants.add(relative_id)
- relative_task = dag.task_dict[relative_id]
- relative_task.get_flat_relative_ids(upstream, found_descendants)
+ task_ids_to_trace = self.get_direct_relative_ids(upstream)
+ while task_ids_to_trace:
+ task_ids_to_trace_next: set[str] = set()
+ for task_id in task_ids_to_trace:
+ if task_id in found_descendants:
+ continue
+ task_ids_to_trace_next.update(dag.task_dict[task_id].get_direct_relative_ids(upstream))
+ found_descendants.add(task_id)
+ task_ids_to_trace = task_ids_to_trace_next
return found_descendants
- def get_flat_relatives(self, upstream: bool = False) -> Collection["Operator"]:
+ def get_flat_relatives(self, upstream: bool = False) -> Collection[Operator]:
"""Get a flat list of relatives, either upstream or downstream."""
dag = self.get_dag()
if not dag:
return set()
return [dag.task_dict[task_id] for task_id in self.get_flat_relative_ids(upstream)]
+ def _iter_all_mapped_downstreams(self) -> Iterator[MappedOperator | MappedTaskGroup]:
+ """Return mapped nodes that are direct dependencies of the current task.
+
+ For now, this walks the entire DAG to find mapped nodes that have this
+ current task as an upstream. We cannot use ``downstream_list`` since it
+ only contains operators, not task groups. In the future, we should
+ provide a way to record all downstream nodes of a DAG node instead.
+
+ Note that this does not guarantee the returned tasks actually use the
+ current task for task mapping, but only checks that those tasks are mapped
+ operators, and are downstreams of the current task.
+
+ To get a list of tasks that use the current task for task mapping, use
+ :meth:`iter_mapped_dependants` instead.
+ """
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.utils.task_group import TaskGroup
+
+ def _walk_group(group: TaskGroup) -> Iterable[tuple[str, DAGNode]]:
+ """Recursively walk children in a task group.
+
+ This yields all direct children (including both tasks and task
+ groups), and all children of any task groups.
+ """
+ for key, child in group.children.items():
+ yield key, child
+ if isinstance(child, TaskGroup):
+ yield from _walk_group(child)
+
+ dag = self.get_dag()
+ if not dag:
+ raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
+ for key, child in _walk_group(dag.task_group):
+ if key == self.node_id:
+ continue
+ if not isinstance(child, (MappedOperator, MappedTaskGroup)):
+ continue
+ if self.node_id in child.upstream_task_ids:
+ yield child
+
+ def iter_mapped_dependants(self) -> Iterator[MappedOperator | MappedTaskGroup]:
+ """Return mapped nodes that depend on the current task for expansion.
+
+ For now, this walks the entire DAG to find mapped nodes that have this
+ current task as an upstream. We cannot use ``downstream_list`` since it
+ only contains operators, not task groups. In the future, we should
+ provide a way to record all downstream nodes of a DAG node instead.
+ """
+ return (
+ downstream
+ for downstream in self._iter_all_mapped_downstreams()
+ if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
+ )
+
+ def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
+ """Return mapped task groups this task belongs to.
+
+ Groups are returned from the closest to the outermost.
+
+ :meta private:
+ """
+ parent = self.task_group
+ while parent is not None:
+ if isinstance(parent, MappedTaskGroup):
+ yield parent
+ parent = parent.task_group
+
+ def get_closest_mapped_task_group(self) -> MappedTaskGroup | None:
+ """:meta private:"""
+ return next(self.iter_mapped_task_groups(), None)
+
+ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
+ """Get the "normal" operator from current abstract operator.
+
+ MappedOperator uses this to unmap itself based on the map index. A non-
+ mapped operator (i.e. BaseOperator subclass) simply returns itself.
+
+ :meta private:
+ """
+ raise NotImplementedError()
+
@property
def priority_weight_total(self) -> int:
"""
@@ -252,9 +339,9 @@ def priority_weight_total(self) -> int:
)
@cached_property
- def operator_extra_link_dict(self) -> Dict[str, Any]:
+ def operator_extra_link_dict(self) -> dict[str, Any]:
"""Returns dictionary of all extra links for the operator"""
- op_extra_links_from_plugin: Dict[str, Any] = {}
+ op_extra_links_from_plugin: dict[str, Any] = {}
from airflow import plugins_manager
plugins_manager.initialize_extra_operators_links_plugins()
@@ -271,7 +358,7 @@ def operator_extra_link_dict(self) -> Dict[str, Any]:
return operator_extra_links_all
@cached_property
- def global_operator_extra_link_dict(self) -> Dict[str, Any]:
+ def global_operator_extra_link_dict(self) -> dict[str, Any]:
"""Returns dictionary of all global extra links"""
from airflow import plugins_manager
@@ -281,10 +368,10 @@ def global_operator_extra_link_dict(self) -> Dict[str, Any]:
return {link.name: link for link in plugins_manager.global_operator_extra_links}
@cached_property
- def extra_links(self) -> List[str]:
+ def extra_links(self) -> list[str]:
return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))
- def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]:
+ def get_extra_links(self, ti: TaskInstance, link_name: str) -> str | None:
"""For an operator, gets the URLs that the ``extra_links`` entry points to.
:meta private:
@@ -295,33 +382,198 @@ def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]:
:param link_name: The name of the link we're looking for the URL for. Should be
one of the options specified in ``extra_links``.
"""
- link: Optional["BaseOperatorLink"] = self.operator_extra_link_dict.get(link_name)
+ link: BaseOperatorLink | None = self.operator_extra_link_dict.get(link_name)
if not link:
link = self.global_operator_extra_link_dict.get(link_name)
if not link:
return None
- # Check for old function signature
+
parameters = inspect.signature(link.get_link).parameters
- args = [name for name, p in parameters.items() if p.kind != p.VAR_KEYWORD]
- if "ti_key" in args:
- return link.get_link(self, ti_key=ti.key)
+ old_signature = all(name != "ti_key" for name, p in parameters.items() if p.kind != p.VAR_KEYWORD)
+
+ if old_signature:
+ return link.get_link(self.unmap(None), ti.dag_run.logical_date) # type: ignore[misc]
+ return link.get_link(self.unmap(None), ti_key=ti.key)
+
+ @cache
+ def get_parse_time_mapped_ti_count(self) -> int:
+ """Number of mapped task instances that can be created on DAG run creation.
+
+ This only considers literal mapped arguments, and raises
+ ``NotFullyPopulated`` when any non-literal values are used for mapping.
+
+ :raise NotFullyPopulated: If non-literal mapped arguments are encountered.
+ :raise NotMapped: If the operator is neither mapped, nor has any parent
+ mapped task groups.
+ :return: Total number of mapped TIs this task should have.
+ """
+ group = self.get_closest_mapped_task_group()
+ if group is None:
+ raise NotMapped
+ return group.get_parse_time_mapped_ti_count()
+
+ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
+ """Number of mapped TaskInstances that can be created at run time.
+
+ This considers both literal and non-literal mapped arguments, and the
+ result is therefore available only once all upstream tasks have finished. The
+ return value should be identical to ``get_parse_time_mapped_ti_count`` if
+ all mapped arguments are literal.
+
+ :raise NotFullyPopulated: If upstream tasks are not all complete yet.
+ :raise NotMapped: If the operator is neither mapped, nor has any parent
+ mapped task groups.
+ :return: Total number of mapped TIs this task should have.
+ """
+ group = self.get_closest_mapped_task_group()
+ if group is None:
+ raise NotMapped
+ return group.get_mapped_ti_count(run_id, session=session)
+
+ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence[TaskInstance], int]:
+ """Create the mapped task instances for a mapped task.
+
+ :raise NotMapped: If this task does not need expansion.
+ :return: The newly created mapped task instances (if any) in ascending
+ order by map index, and the maximum map index value.
+ """
+ from sqlalchemy import func, or_
+
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.models.mappedoperator import MappedOperator
+ from airflow.models.taskinstance import TaskInstance
+ from airflow.settings import task_instance_mutation_hook
+
+ if not isinstance(self, (BaseOperator, MappedOperator)):
+ raise RuntimeError(f"cannot expand unrecognized operator type {type(self).__name__}")
+
+ try:
+ total_length: int | None = self.get_mapped_ti_count(run_id, session=session)
+ except NotFullyPopulated as e:
+ # It's possible that the upstream tasks are not yet done, but we
+ # don't have upstream of upstreams in partial DAGs (possible in the
+ # mini-scheduler), so we ignore this exception.
+ if not self.dag or not self.dag.partial:
+ self.log.error(
+ "Cannot expand %r for run %s; missing upstream values: %s",
+ self,
+ run_id,
+ sorted(e.missing),
+ )
+ total_length = None
+
+ state: TaskInstanceState | None = None
+ unmapped_ti: TaskInstance | None = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.map_index == -1,
+ or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
+ )
+ .one_or_none()
+ )
+
+ all_expanded_tis: list[TaskInstance] = []
+
+ if unmapped_ti:
+ # The unmapped task instance still exists and is unfinished, i.e. we
+ # haven't tried to run it before.
+ if total_length is None:
+ # If the DAG is partial, it's likely that the upstream tasks
+ # are not done yet, so the task can't fail yet.
+ if not self.dag or not self.dag.partial:
+ unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
+ elif total_length < 1:
+ # If the upstream maps this to a zero-length value, simply mark
+ # the unmapped task instance as SKIPPED (if needed).
+ self.log.info(
+ "Marking %s as SKIPPED since the map has %d values to expand",
+ unmapped_ti,
+ total_length,
+ )
+ unmapped_ti.state = TaskInstanceState.SKIPPED
+ else:
+ zero_index_ti_exists = (
+ session.query(TaskInstance)
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.map_index == 0,
+ )
+ .count()
+ > 0
+ )
+ if not zero_index_ti_exists:
+ # Otherwise convert this into the first mapped index, and create
+ # TaskInstance for other indexes.
+ unmapped_ti.map_index = 0
+ self.log.debug("Updated in place to become %s", unmapped_ti)
+ all_expanded_tis.append(unmapped_ti)
+ session.flush()
+ else:
+ self.log.debug("Deleting the original task instance: %s", unmapped_ti)
+ session.delete(unmapped_ti)
+ state = unmapped_ti.state
+
+ if total_length is None or total_length < 1:
+ # Nothing to fixup.
+ indexes_to_map: Iterable[int] = ()
else:
- return link.get_link(self, ti.dag_run.logical_date) # type: ignore[misc]
- return None
+ # Only create "missing" ones.
+ current_max_mapping = (
+ session.query(func.max(TaskInstance.map_index))
+ .filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == run_id,
+ )
+ .scalar()
+ )
+ indexes_to_map = range(current_max_mapping + 1, total_length)
+
+ for index in indexes_to_map:
+ # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
+ ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
+ self.log.debug("Expanding TIs upserted %s", ti)
+ task_instance_mutation_hook(ti)
+ ti = session.merge(ti)
+ ti.refresh_from_task(self) # session.merge() loses task information.
+ all_expanded_tis.append(ti)
+
+ # Coerce the None case to 0 -- these two are almost treated identically,
+ # except the unmapped ti (if exists) is marked to different states.
+ total_expanded_ti_count = total_length or 0
+
+ # Any (old) task instances with inapplicable indexes (>= the total
+ # number we need) are set to "REMOVED".
+ query = session.query(TaskInstance).filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == run_id,
+ TaskInstance.map_index >= total_expanded_ti_count,
+ )
+ to_update = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
+ for ti in to_update:
+ ti.state = TaskInstanceState.REMOVED
+ session.flush()
+ return all_expanded_tis, total_expanded_ti_count - 1
def render_template_fields(
self,
context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- ) -> Optional["BaseOperator"]:
- """Template all attributes listed in template_fields.
+ jinja_env: jinja2.Environment | None = None,
+ ) -> None:
+ """Template all attributes listed in *self.template_fields*.
If the operator is mapped, this should return the unmapped, fully
rendered, and map-expanded operator. The mapped operator should not be
- modified.
+ modified. However, *context* may be modified in-place to reference the
+ unmapped operator for template rendering.
- If the operator is not mapped, this should modify the operator in-place
- and return either *None* (for backwards compatibility) or *self*.
+ If the operator is not mapped, this should modify the operator in-place.
"""
raise NotImplementedError()
@@ -331,10 +583,10 @@ def _do_render_template_fields(
parent: Any,
template_fields: Iterable[str],
context: Context,
- jinja_env: "jinja2.Environment",
- seen_oids: Set,
+ jinja_env: jinja2.Environment,
+ seen_oids: set[int],
*,
- session: "Session" = NEW_SESSION,
+ session: Session = NEW_SESSION,
) -> None:
for attr_name in template_fields:
try:
@@ -346,20 +598,30 @@ def _do_render_template_fields(
)
if not value:
continue
- rendered_content = self.render_template(
- value,
- context,
- jinja_env,
- seen_oids,
- )
- setattr(parent, attr_name, rendered_content)
+ try:
+ rendered_content = self.render_template(
+ value,
+ context,
+ jinja_env,
+ seen_oids,
+ )
+ except Exception:
+ self.log.exception(
+ "Exception rendering Jinja template for task '%s', field '%s'. Template: %r",
+ self.task_id,
+ attr_name,
+ value,
+ )
+ raise
+ else:
+ setattr(parent, attr_name, rendered_content)
def render_template(
self,
content: Any,
context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- seen_oids: Optional[Set] = None,
+ jinja_env: jinja2.Environment | None = None,
+ seen_oids: set[int] | None = None,
) -> Any:
"""Render a templated string.
@@ -379,12 +641,17 @@ def render_template(
value = content
del content
+ if seen_oids is not None:
+ oids = seen_oids
+ else:
+ oids = set()
+
+ if id(value) in oids:
+ return value
+
if not jinja_env:
jinja_env = self.get_template_env()
- from airflow.models.param import DagParam
- from airflow.models.xcom_arg import XComArg
-
if isinstance(value, str):
if any(value.endswith(ext) for ext in self.template_ext): # A filepath.
template = jinja_env.get_template(value)
@@ -395,26 +662,22 @@ def render_template(
return render_template_as_native(template, context)
return render_template_to_string(template, context)
- if isinstance(value, (DagParam, XComArg)):
+ if isinstance(value, ResolveMixin):
return value.resolve(context)
# Fast path for common built-in collections.
if value.__class__ is tuple:
- return tuple(self.render_template(element, context, jinja_env) for element in value)
+ return tuple(self.render_template(element, context, jinja_env, oids) for element in value)
elif isinstance(value, tuple): # Special case for named tuples.
- return value.__class__(*(self.render_template(el, context, jinja_env) for el in value))
+ return value.__class__(*(self.render_template(el, context, jinja_env, oids) for el in value))
elif isinstance(value, list):
- return [self.render_template(element, context, jinja_env) for element in value]
+ return [self.render_template(element, context, jinja_env, oids) for element in value]
elif isinstance(value, dict):
- return {key: self.render_template(value, context, jinja_env) for key, value in value.items()}
+ return {k: self.render_template(v, context, jinja_env, oids) for k, v in value.items()}
elif isinstance(value, set):
- return {self.render_template(element, context, jinja_env) for element in value}
+ return {self.render_template(element, context, jinja_env, oids) for element in value}
# More complex collections.
- if seen_oids is None:
- oids = set()
- else:
- oids = seen_oids
self._render_nested_template_fields(value, context, jinja_env, oids)
return value
@@ -422,8 +685,8 @@ def _render_nested_template_fields(
self,
value: Any,
context: Context,
- jinja_env: "jinja2.Environment",
- seen_oids: Set[int],
+ jinja_env: jinja2.Environment,
+ seen_oids: set[int],
) -> None:
if id(value) in seen_oids:
return
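Note on the ``render_template`` hunks above: the change threads a single ``oids`` set through every recursive call and returns early when ``id(value)`` is already in it, so a self-referencing structure is visited only once. The following is a minimal, self-contained sketch of that idea and not Airflow's implementation. The ``Box`` class and the ``render`` function are hypothetical, and ``str.format`` stands in for Jinja rendering.

```python
from typing import Any, Mapping, Optional, Set


class Box:
    """Hypothetical container with one templated attribute, ``payload``."""

    def __init__(self, payload: Any) -> None:
        self.payload = payload


def render(value: Any, context: Mapping[str, str], seen_oids: Optional[Set[int]] = None) -> Any:
    """Render ``value`` recursively, visiting each ``Box`` at most once."""
    # Compare against None explicitly: an empty set passed by the caller
    # must be reused, not replaced by a fresh one.
    oids = seen_oids if seen_oids is not None else set()
    if id(value) in oids:
        return value  # Already visited: stop instead of recursing forever.
    if isinstance(value, str):
        return value.format(**context)  # Stand-in for Jinja rendering.
    if isinstance(value, list):
        return [render(item, context, oids) for item in value]
    if isinstance(value, dict):
        return {k: render(v, context, oids) for k, v in value.items()}
    if isinstance(value, Box):
        oids.add(id(value))  # Record before descending into the attribute.
        value.payload = render(value.payload, context, oids)
    return value


# A Box whose payload refers back to the Box itself.
box = Box(None)
box.payload = ["run {ds}", box]
render(box, {"ds": "2022-01-01"})
print(box.payload[0])
```

Without passing ``oids`` into the nested calls, the inner ``render(box, ...)`` would start from an empty set and recurse without end, which is the situation the added ``oids`` arguments in the hunk guard against.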
diff --git a/airflow/models/base.py b/airflow/models/base.py
index 478bd904eb83c..b2587cc1767b8 100644
--- a/airflow/models/base.py
+++ b/airflow/models/base.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from typing import Any
@@ -25,9 +26,26 @@
SQL_ALCHEMY_SCHEMA = conf.get("database", "SQL_ALCHEMY_SCHEMA")
-metadata = (
- None if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace() else MetaData(schema=SQL_ALCHEMY_SCHEMA)
-)
+# For more information about what the tokens in the naming convention
+# below mean, see:
+# https://docs.sqlalchemy.org/en/14/core/metadata.html#sqlalchemy.schema.MetaData.params.naming_convention
+naming_convention = {
+ "ix": "idx_%(column_0_N_label)s",
+ "uq": "%(table_name)s_%(column_0_N_name)s_uq",
+ "ck": "ck_%(table_name)s_%(constraint_name)s",
+ "fk": "%(table_name)s_%(column_0_name)s_fkey",
+ "pk": "%(table_name)s_pkey",
+}
+
+
+def _get_schema():
+ if not SQL_ALCHEMY_SCHEMA or SQL_ALCHEMY_SCHEMA.isspace():
+ return None
+ return SQL_ALCHEMY_SCHEMA
+
+
+metadata = MetaData(schema=_get_schema(), naming_convention=naming_convention)
+
Base: Any = declarative_base(metadata=metadata)
ID_LEN = 250
@@ -35,9 +53,9 @@
def get_id_collation_args():
"""Get SQLAlchemy args to use for COLLATION"""
- collation = conf.get('database', 'sql_engine_collation_for_ids', fallback=None)
+ collation = conf.get("database", "sql_engine_collation_for_ids", fallback=None)
if collation:
- return {'collation': collation}
+ return {"collation": collation}
else:
# Automatically use utf8mb3_bin collation for mysql
# This is backwards-compatible. All our IDS are ASCII anyway so even if
@@ -49,9 +67,9 @@ def get_id_collation_args():
#
# We cannot use session/dialect as at this point we are trying to determine the right connection
# parameters, so we use the connection
- conn = conf.get('database', 'sql_alchemy_conn', fallback='')
- if conn.startswith('mysql') or conn.startswith("mariadb"):
- return {'collation': 'utf8mb3_bin'}
+ conn = conf.get("database", "sql_alchemy_conn", fallback="")
+ if conn.startswith("mysql") or conn.startswith("mariadb"):
+ return {"collation": "utf8mb3_bin"}
return {}
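Note on the ``naming_convention`` mapping added in the ``base.py`` hunk above: each entry is a ``%``-style template whose tokens SQLAlchemy fills in when it names a constraint or index. The sketch below only previews the resulting name shapes with plain string formatting. SQLAlchemy performs the real substitution itself, and the helper ``preview_name`` as well as the sample table and column names are hypothetical.

```python
# Templates copied from the diff above.
naming_convention = {
    "ix": "idx_%(column_0_N_label)s",
    "uq": "%(table_name)s_%(column_0_N_name)s_uq",
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    "fk": "%(table_name)s_%(column_0_name)s_fkey",
    "pk": "%(table_name)s_pkey",
}


def preview_name(kind: str, **tokens: str) -> str:
    """Fill one template with hand-written token values (illustration only)."""
    return naming_convention[kind] % tokens


# Assumed sample values; real token values come from the table metadata.
print(preview_name("pk", table_name="dag_run"))
print(preview_name("fk", table_name="task_instance", column_0_name="dag_id"))
print(preview_name("uq", table_name="dag_run", column_0_N_name="dag_id_run_id"))
```

Deterministic names of this kind make it possible to refer to a constraint by name in later migrations, regardless of which database backend generated it.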
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 3208c435769b3..9448c4a415db9 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Base operator for all operators."""
+from __future__ import annotations
+
import abc
import collections
import collections.abc
@@ -35,14 +37,9 @@
Callable,
ClassVar,
Collection,
- Dict,
- FrozenSet,
Iterable,
List,
- Optional,
Sequence,
- Set,
- Tuple,
Type,
TypeVar,
Union,
@@ -56,7 +53,7 @@
from sqlalchemy.orm.exc import NoResultFound
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskDeferred
from airflow.lineage import apply_lineage, prepare_lineage
from airflow.models.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -87,6 +84,7 @@
from airflow.triggers.base import BaseTrigger
from airflow.utils import timezone
from airflow.utils.context import Context
+from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.helpers import validate_key
from airflow.utils.operator_resources import Resources
from airflow.utils.session import NEW_SESSION, provide_session
@@ -98,6 +96,7 @@
from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstanceKey
+ from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup
ScheduleInterval = Union[str, timedelta, relativedelta]
@@ -105,12 +104,12 @@
TaskPreExecuteHook = Callable[[Context], None]
TaskPostExecuteHook = Callable[[Context, Any], None]
-T = TypeVar('T', bound=FunctionType)
+T = TypeVar("T", bound=FunctionType)
logger = logging.getLogger("airflow.models.baseoperator.BaseOperator")
-def parse_retries(retries: Any) -> Optional[int]:
+def parse_retries(retries: Any) -> int | None:
if retries is None or isinstance(retries, int):
return retries
try:
@@ -121,20 +120,20 @@ def parse_retries(retries: Any) -> Optional[int]:
return parsed_retries
-def coerce_timedelta(value: Union[float, timedelta], *, key: str) -> timedelta:
+def coerce_timedelta(value: float | timedelta, *, key: str) -> timedelta:
if isinstance(value, timedelta):
return value
logger.debug("%s isn't a timedelta object, assuming secs", key)
return timedelta(seconds=value)
-def coerce_resources(resources: Optional[Dict[str, Any]]) -> Optional[Resources]:
+def coerce_resources(resources: dict[str, Any] | None) -> Resources | None:
if resources is None:
return None
return Resources(**resources)
-def _get_parent_defaults(dag: Optional["DAG"], task_group: Optional["TaskGroup"]) -> Tuple[dict, ParamsDict]:
+def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]:
if not dag:
return {}, ParamsDict()
dag_args = copy.copy(dag.default_args)
@@ -147,11 +146,11 @@ def _get_parent_defaults(dag: Optional["DAG"], task_group: Optional["TaskGroup"]
def get_merged_defaults(
- dag: Optional["DAG"],
- task_group: Optional["TaskGroup"],
- task_params: Optional[dict],
- task_default_args: Optional[dict],
-) -> Tuple[dict, ParamsDict]:
+ dag: DAG | None,
+ task_group: TaskGroup | None,
+ task_params: dict | None,
+ task_default_args: dict | None,
+) -> tuple[dict, ParamsDict]:
args, params = _get_parent_defaults(dag, task_group)
if task_params:
if not isinstance(task_params, collections.abc.Mapping):
@@ -172,7 +171,7 @@ class _PartialDescriptor:
class_method = None
def __get__(
- self, obj: "BaseOperator", cls: "Optional[Type[BaseOperator]]" = None
+ self, obj: BaseOperator, cls: type[BaseOperator] | None = None
) -> Callable[..., OperatorPartial]:
# Call this "partial" so it looks nicer in stack traces.
def partial(**kwargs):
@@ -185,46 +184,46 @@ def partial(**kwargs):
# This is what handles the actual mapping.
def partial(
- operator_class: Type["BaseOperator"],
+ operator_class: type[BaseOperator],
*,
task_id: str,
- dag: Optional["DAG"] = None,
- task_group: Optional["TaskGroup"] = None,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ dag: DAG | None = None,
+ task_group: TaskGroup | None = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
owner: str = DEFAULT_OWNER,
- email: Union[None, str, Iterable[str]] = None,
- params: Optional[dict] = None,
- resources: Optional[Dict[str, Any]] = None,
+ email: None | str | Iterable[str] = None,
+ params: dict | None = None,
+ resources: dict[str, Any] | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
depends_on_past: bool = False,
ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
wait_for_downstream: bool = False,
- retries: Optional[int] = DEFAULT_RETRIES,
+ retries: int | None = DEFAULT_RETRIES,
queue: str = DEFAULT_QUEUE,
- pool: Optional[str] = None,
+ pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
- execution_timeout: Optional[timedelta] = DEFAULT_TASK_EXECUTION_TIMEOUT,
- max_retry_delay: Union[None, timedelta, float] = None,
- retry_delay: Union[timedelta, float] = DEFAULT_RETRY_DELAY,
+ execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
+ max_retry_delay: None | timedelta | float = None,
+ retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
- sla: Optional[timedelta] = None,
- max_active_tis_per_dag: Optional[int] = None,
- on_execute_callback: Optional[TaskStateChangeCallback] = None,
- on_failure_callback: Optional[TaskStateChangeCallback] = None,
- on_success_callback: Optional[TaskStateChangeCallback] = None,
- on_retry_callback: Optional[TaskStateChangeCallback] = None,
- run_as_user: Optional[str] = None,
- executor_config: Optional[Dict] = None,
- inlets: Optional[Any] = None,
- outlets: Optional[Any] = None,
- doc: Optional[str] = None,
- doc_md: Optional[str] = None,
- doc_json: Optional[str] = None,
- doc_yaml: Optional[str] = None,
- doc_rst: Optional[str] = None,
+ sla: timedelta | None = None,
+ max_active_tis_per_dag: int | None = None,
+ on_execute_callback: TaskStateChangeCallback | None = None,
+ on_failure_callback: TaskStateChangeCallback | None = None,
+ on_success_callback: TaskStateChangeCallback | None = None,
+ on_retry_callback: TaskStateChangeCallback | None = None,
+ run_as_user: str | None = None,
+ executor_config: dict | None = None,
+ inlets: Any | None = None,
+ outlets: Any | None = None,
+ doc: str | None = None,
+ doc_md: str | None = None,
+ doc_json: str | None = None,
+ doc_yaml: str | None = None,
+ doc_rst: str | None = None,
**kwargs,
) -> OperatorPartial:
from airflow.models.dag import DagContext
@@ -239,7 +238,7 @@ def partial(
task_id = task_group.child_id(task_id)
# Merge DAG and task group level defaults into user-supplied values.
- partial_kwargs, default_params = get_merged_defaults(
+ partial_kwargs, partial_params = get_merged_defaults(
dag=dag,
task_group=task_group,
task_params=params,
@@ -255,7 +254,6 @@ def partial(
partial_kwargs.setdefault("end_date", end_date)
partial_kwargs.setdefault("owner", owner)
partial_kwargs.setdefault("email", email)
- partial_kwargs.setdefault("params", default_params)
partial_kwargs.setdefault("trigger_rule", trigger_rule)
partial_kwargs.setdefault("depends_on_past", depends_on_past)
partial_kwargs.setdefault("ignore_first_depends_on_past", ignore_first_depends_on_past)
@@ -278,8 +276,8 @@ def partial(
partial_kwargs.setdefault("on_success_callback", on_success_callback)
partial_kwargs.setdefault("run_as_user", run_as_user)
partial_kwargs.setdefault("executor_config", executor_config)
- partial_kwargs.setdefault("inlets", inlets)
- partial_kwargs.setdefault("outlets", outlets)
+ partial_kwargs.setdefault("inlets", inlets or [])
+ partial_kwargs.setdefault("outlets", outlets or [])
partial_kwargs.setdefault("resources", resources)
partial_kwargs.setdefault("doc", doc)
partial_kwargs.setdefault("doc_json", doc_json)
@@ -306,7 +304,11 @@ def partial(
partial_kwargs["executor_config"] = partial_kwargs["executor_config"] or {}
partial_kwargs["resources"] = coerce_resources(partial_kwargs["resources"])
- return OperatorPartial(operator_class=operator_class, kwargs=partial_kwargs)
+ return OperatorPartial(
+ operator_class=operator_class,
+ kwargs=partial_kwargs,
+ params=partial_params,
+ )
class BaseOperatorMeta(abc.ABCMeta):
@@ -331,7 +333,7 @@ def _apply_defaults(cls, func: T) -> T:
non_variadic_params = {
name: param
for (name, param) in sig_cache.parameters.items()
- if param.name != 'self' and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
+ if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
}
non_optional_args = {
name
@@ -339,28 +341,10 @@ def _apply_defaults(cls, func: T) -> T:
if param.default == param.empty and name != "task_id"
}
- class autostacklevel_warn:
- def __init__(self):
- self.warnings = __import__('warnings')
-
- def __getattr__(self, name):
- return getattr(self.warnings, name)
-
- def __dir__(self):
- return dir(self.warnings)
-
- def warn(self, message, category=None, stacklevel=1, source=None):
- self.warnings.warn(message, category, stacklevel + 2, source)
-
- if func.__globals__.get('warnings') is sys.modules['warnings']:
- # Yes, this is slightly hacky, but it _automatically_ sets the right
- # stacklevel parameter to `warnings.warn` to ignore the decorator. Now
- # that the decorator is applied automatically, this makes the needed
- # stacklevel parameter less confusing.
- func.__globals__['warnings'] = autostacklevel_warn()
+ fixup_decorator_warning_stack(func)
@functools.wraps(func)
- def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
+ def apply_defaults(self: BaseOperator, *args: Any, **kwargs: Any) -> Any:
from airflow.models.dag import DagContext
from airflow.utils.task_group import TaskGroupContext
@@ -372,8 +356,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
getattr(self, "_BaseOperator__from_mapped", False),
)
- dag: Optional[DAG] = kwargs.get('dag') or DagContext.get_current_dag()
- task_group: Optional[TaskGroup] = kwargs.get('task_group')
+ dag: DAG | None = kwargs.get("dag") or DagContext.get_current_dag()
+ task_group: TaskGroup | None = kwargs.get("task_group")
if dag and not task_group:
task_group = TaskGroupContext.get_current_task_group(dag)
@@ -398,12 +382,12 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
if merged_params:
kwargs["params"] = merged_params
- hook = getattr(self, '_hook_apply_defaults', None)
+ hook = getattr(self, "_hook_apply_defaults", None)
if hook:
args, kwargs = hook(**kwargs, default_args=default_args)
- default_args = kwargs.pop('default_args', {})
+ default_args = kwargs.pop("default_args", {})
- if not hasattr(self, '_BaseOperator__init_kwargs'):
+ if not hasattr(self, "_BaseOperator__init_kwargs"):
self._BaseOperator__init_kwargs = {}
self._BaseOperator__from_mapped = instantiated_from_mapped
@@ -412,8 +396,9 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
# Store the args passed to init -- we need them to support task.map serialization!
self._BaseOperator__init_kwargs.update(kwargs) # type: ignore
- if not instantiated_from_mapped:
- # Set upstream task defined by XComArgs passed to template fields of the operator.
+ # Set upstream task defined by XComArgs passed to template fields of the operator.
+ # BUT: only do this _ONCE_, not once for each class in the hierarchy
+ if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc]
self.set_xcomargs_dependencies()
# Mark instance as instantiated.
self._BaseOperator__instantiated = True
@@ -428,8 +413,8 @@ def apply_defaults(self: "BaseOperator", *args: Any, **kwargs: Any) -> Any:
def __new__(cls, name, bases, namespace, **kwargs):
new_cls = super().__new__(cls, name, bases, namespace, **kwargs)
with contextlib.suppress(KeyError):
- # Update the partial descriptor with the class method so it call call the actual function (but let
- # subclasses override it if they need to)
+ # Update the partial descriptor with the class method, so it calls the actual function
+ # (but let subclasses override it if they need to)
partial_desc = vars(new_cls)["partial"]
if isinstance(partial_desc, _PartialDescriptor):
partial_desc.class_method = classmethod(partial)
@@ -463,7 +448,7 @@ class derived from this one results in the creation of a task object,
(e.g. user/person/team/role name) to clarify ownership is recommended.
:param email: the 'to' email address(es) used in email alerts. This can be a
single email or multiple ones. Multiple addresses can be specified as a
- comma or semi-colon separated string or by passing a list of strings.
+ comma or semicolon separated string or by passing a list of strings.
:param email_on_retry: Indicates whether email alerts should be sent when a
task is retried
:param email_on_failure: Indicates whether email alerts should be sent when
@@ -576,7 +561,7 @@ class derived from this one results in the creation of a task object,
|experimental|
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
- ``{ all_success | all_failed | all_done | all_skipped | one_success |
+ ``{ all_success | all_failed | all_done | all_skipped | one_success | one_done |
one_failed | none_failed | none_failed_min_one_success | none_skipped | always}``
default is ``all_success``. Options can be set as string or
using the constants defined in the static class
@@ -620,54 +605,54 @@ class derived from this one results in the creation of a task object,
template_fields: Sequence[str] = ()
template_ext: Sequence[str] = ()
- template_fields_renderers: Dict[str, str] = {}
+ template_fields_renderers: dict[str, str] = {}
# Defines the color in the UI
- ui_color: str = '#fff'
- ui_fgcolor: str = '#000'
+ ui_color: str = "#fff"
+ ui_fgcolor: str = "#000"
pool: str = ""
# base list which includes all the attrs that don't need deep copy.
- _base_operator_shallow_copy_attrs: Tuple[str, ...] = (
- 'user_defined_macros',
- 'user_defined_filters',
- 'params',
- '_log',
+ _base_operator_shallow_copy_attrs: tuple[str, ...] = (
+ "user_defined_macros",
+ "user_defined_filters",
+ "params",
+ "_log",
)
# each operator should override this class attr for shallow copy attrs.
shallow_copy_attrs: Sequence[str] = ()
# Defines the operator level extra links
- operator_extra_links: Collection['BaseOperatorLink'] = ()
+ operator_extra_links: Collection[BaseOperatorLink] = ()
# The _serialized_fields are lazily loaded when get_serialized_fields() method is called
- __serialized_fields: Optional[FrozenSet[str]] = None
+ __serialized_fields: frozenset[str] | None = None
partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore
_comps = {
- 'task_id',
- 'dag_id',
- 'owner',
- 'email',
- 'email_on_retry',
- 'retry_delay',
- 'retry_exponential_backoff',
- 'max_retry_delay',
- 'start_date',
- 'end_date',
- 'depends_on_past',
- 'wait_for_downstream',
- 'priority_weight',
- 'sla',
- 'execution_timeout',
- 'on_execute_callback',
- 'on_failure_callback',
- 'on_success_callback',
- 'on_retry_callback',
- 'do_xcom_push',
+ "task_id",
+ "dag_id",
+ "owner",
+ "email",
+ "email_on_retry",
+ "retry_delay",
+ "retry_exponential_backoff",
+ "max_retry_delay",
+ "start_date",
+ "end_date",
+ "depends_on_past",
+ "wait_for_downstream",
+ "priority_weight",
+ "sla",
+ "execution_timeout",
+ "on_execute_callback",
+ "on_failure_callback",
+ "on_success_callback",
+ "on_retry_callback",
+ "do_xcom_push",
}
# Defines if the operator supports lineage without manual definitions
@@ -677,25 +662,20 @@ class derived from this one results in the creation of a task object,
__instantiated = False
# List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task
# when mapping
- __init_kwargs: Dict[str, Any]
+ __init_kwargs: dict[str, Any]
# Set to True before calling execute method
_lock_for_execution = False
- _dag: Optional["DAG"] = None
- task_group: Optional["TaskGroup"] = None
+ _dag: DAG | None = None
+ task_group: TaskGroup | None = None
# subdag parameter is only set for SubDagOperator.
# Setting it to None by default as other Operators do not have that field
- subdag: Optional["DAG"] = None
-
- start_date: Optional[pendulum.DateTime] = None
- end_date: Optional[pendulum.DateTime] = None
+ subdag: DAG | None = None
- # How operator-mapping arguments should be validated. If True, a default validation implementation that
- # calls the operator's constructor is used. If False, the operator should implement its own validation
- # logic (default implementation is 'pass' i.e. no validation whatsoever).
- mapped_arguments_validated_by_init: ClassVar[bool] = False
+ start_date: pendulum.DateTime | None = None
+ end_date: pendulum.DateTime | None = None
# Set to True for an operator instantiated by a mapped operator.
__from_mapped = False
@@ -704,49 +684,49 @@ def __init__(
self,
task_id: str,
owner: str = DEFAULT_OWNER,
- email: Optional[Union[str, Iterable[str]]] = None,
- email_on_retry: bool = conf.getboolean('email', 'default_email_on_retry', fallback=True),
- email_on_failure: bool = conf.getboolean('email', 'default_email_on_failure', fallback=True),
- retries: Optional[int] = DEFAULT_RETRIES,
- retry_delay: Union[timedelta, float] = DEFAULT_RETRY_DELAY,
+ email: str | Iterable[str] | None = None,
+ email_on_retry: bool = conf.getboolean("email", "default_email_on_retry", fallback=True),
+ email_on_failure: bool = conf.getboolean("email", "default_email_on_failure", fallback=True),
+ retries: int | None = DEFAULT_RETRIES,
+ retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
retry_exponential_backoff: bool = False,
- max_retry_delay: Optional[Union[timedelta, float]] = None,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ max_retry_delay: timedelta | float | None = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
depends_on_past: bool = False,
ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
wait_for_downstream: bool = False,
- dag: Optional['DAG'] = None,
- params: Optional[Dict] = None,
- default_args: Optional[Dict] = None,
+ dag: DAG | None = None,
+ params: dict | None = None,
+ default_args: dict | None = None,
priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
weight_rule: str = DEFAULT_WEIGHT_RULE,
queue: str = DEFAULT_QUEUE,
- pool: Optional[str] = None,
+ pool: str | None = None,
pool_slots: int = DEFAULT_POOL_SLOTS,
- sla: Optional[timedelta] = None,
- execution_timeout: Optional[timedelta] = DEFAULT_TASK_EXECUTION_TIMEOUT,
- on_execute_callback: Optional[TaskStateChangeCallback] = None,
- on_failure_callback: Optional[TaskStateChangeCallback] = None,
- on_success_callback: Optional[TaskStateChangeCallback] = None,
- on_retry_callback: Optional[TaskStateChangeCallback] = None,
- pre_execute: Optional[TaskPreExecuteHook] = None,
- post_execute: Optional[TaskPostExecuteHook] = None,
+ sla: timedelta | None = None,
+ execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT,
+ on_execute_callback: TaskStateChangeCallback | None = None,
+ on_failure_callback: TaskStateChangeCallback | None = None,
+ on_success_callback: TaskStateChangeCallback | None = None,
+ on_retry_callback: TaskStateChangeCallback | None = None,
+ pre_execute: TaskPreExecuteHook | None = None,
+ post_execute: TaskPostExecuteHook | None = None,
trigger_rule: str = DEFAULT_TRIGGER_RULE,
- resources: Optional[Dict[str, Any]] = None,
- run_as_user: Optional[str] = None,
- task_concurrency: Optional[int] = None,
- max_active_tis_per_dag: Optional[int] = None,
- executor_config: Optional[Dict] = None,
+ resources: dict[str, Any] | None = None,
+ run_as_user: str | None = None,
+ task_concurrency: int | None = None,
+ max_active_tis_per_dag: int | None = None,
+ executor_config: dict | None = None,
do_xcom_push: bool = True,
- inlets: Optional[Any] = None,
- outlets: Optional[Any] = None,
- task_group: Optional["TaskGroup"] = None,
- doc: Optional[str] = None,
- doc_md: Optional[str] = None,
- doc_json: Optional[str] = None,
- doc_yaml: Optional[str] = None,
- doc_rst: Optional[str] = None,
+ inlets: Any | None = None,
+ outlets: Any | None = None,
+ task_group: TaskGroup | None = None,
+ doc: str | None = None,
+ doc_md: str | None = None,
+ doc_json: str | None = None,
+ doc_yaml: str | None = None,
+ doc_rst: str | None = None,
**kwargs,
):
from airflow.models.dag import DagContext
@@ -756,17 +736,18 @@ def __init__(
super().__init__()
+ kwargs.pop("_airflow_mapped_validation_only", None)
if kwargs:
- if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
+ if not conf.getboolean("operators", "ALLOW_ILLEGAL_ARGUMENTS"):
raise AirflowException(
f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
f"Invalid arguments were:\n**kwargs: {kwargs}",
)
warnings.warn(
- f'Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). '
- 'Support for passing such arguments will be dropped in future. '
- f'Invalid arguments were:\n**kwargs: {kwargs}',
- category=PendingDeprecationWarning,
+ f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). "
+ "Support for passing such arguments will be dropped in future. "
+ f"Invalid arguments were:\n**kwargs: {kwargs}",
+ category=RemovedInAirflow3Warning,
stacklevel=3,
)
validate_key(task_id)
@@ -785,7 +766,7 @@ def __init__(
if execution_timeout is not None and not isinstance(execution_timeout, timedelta):
raise ValueError(
- f'execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}'
+ f"execution_timeout must be timedelta object but passed as type: {type(execution_timeout)}"
)
self.execution_timeout = execution_timeout
@@ -818,7 +799,7 @@ def __init__(
if trigger_rule == "dummy":
warnings.warn(
"dummy Trigger Rule is deprecated. Please use `TriggerRule.ALWAYS`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
trigger_rule = TriggerRule.ALWAYS
@@ -827,7 +808,7 @@ def __init__(
warnings.warn(
"none_failed_or_skipped Trigger Rule is deprecated. "
"Please use `none_failed_min_one_success`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
@@ -838,7 +819,7 @@ def __init__(
f"'{dag.dag_id if dag else ''}.{task_id}'; received '{trigger_rule}'."
)
- self.trigger_rule = TriggerRule(trigger_rule)
+ self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
self.depends_on_past: bool = depends_on_past
self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
self.wait_for_downstream: bool = wait_for_downstream
@@ -854,7 +835,7 @@ def __init__(
)
# At execution_time this becomes a normal dict
- self.params: Union[ParamsDict, dict] = ParamsDict(params)
+ self.params: ParamsDict | dict = ParamsDict(params)
if priority_weight is not None and not isinstance(priority_weight, int):
raise AirflowException(
f"`priority_weight` for task '{self.task_id}' only accepts integers, "
@@ -873,11 +854,11 @@ def __init__(
# TODO: Remove in Airflow 3.0
warnings.warn(
"The 'task_concurrency' parameter is deprecated. Please use 'max_active_tis_per_dag'.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
max_active_tis_per_dag = task_concurrency
- self.max_active_tis_per_dag: Optional[int] = max_active_tis_per_dag
+ self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
self.do_xcom_push = do_xcom_push
self.doc_md = doc_md
@@ -886,8 +867,8 @@ def __init__(
self.doc_rst = doc_rst
self.doc = doc
- self.upstream_task_ids: Set[str] = set()
- self.downstream_task_ids: Set[str] = set()
+ self.upstream_task_ids: set[str] = set()
+ self.downstream_task_ids: set[str] = set()
if dag:
self.dag = dag
@@ -895,14 +876,11 @@ def __init__(
self._log = logging.getLogger("airflow.task.operators")
# Lineage
- self.inlets: List = []
- self.outlets: List = []
-
- self._inlets: List = []
- self._outlets: List = []
+ self.inlets: list = []
+ self.outlets: list = []
if inlets:
- self._inlets = (
+ self.inlets = (
inlets
if isinstance(inlets, list)
else [
@@ -911,7 +889,7 @@ def __init__(
)
if outlets:
- self._outlets = (
+ self.outlets = (
outlets
if isinstance(outlets, list)
else [
@@ -954,11 +932,11 @@ def __hash__(self):
def __or__(self, other):
"""
Called for [This Operator] | [Operator], The inlets of other
- will be set to pickup the outlets from this operator. Other will
+ will be set to pick up the outlets from this operator. Other will
be set as a downstream task of this operator.
"""
if isinstance(other, BaseOperator):
- if not self._outlets and not self.supports_lineage:
+ if not self.outlets and not self.supports_lineage:
raise ValueError("No outlets defined for this operator")
other.add_inlets([self.task_id])
self.set_downstream(other)
@@ -1014,33 +992,33 @@ def __setattr__(self, key, value):
def add_inlets(self, inlets: Iterable[Any]):
"""Sets inlets to this operator"""
- self._inlets.extend(inlets)
+ self.inlets.extend(inlets)
def add_outlets(self, outlets: Iterable[Any]):
"""Defines the outlets of this operator"""
- self._outlets.extend(outlets)
+ self.outlets.extend(outlets)
def get_inlet_defs(self):
- """:return: list of inlets defined for this operator"""
- return self._inlets
+ """:meta private:"""
+ return self.inlets
def get_outlet_defs(self):
- """:return: list of outlets defined for this operator"""
- return self._outlets
+ """:meta private:"""
+ return self.outlets
- def get_dag(self) -> "Optional[DAG]":
+ def get_dag(self) -> DAG | None:
return self._dag
@property # type: ignore[override]
- def dag(self) -> 'DAG': # type: ignore[override]
+ def dag(self) -> DAG: # type: ignore[override]
"""Returns the Operator's DAG if set, otherwise raises an error"""
if self._dag:
return self._dag
else:
- raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
+ raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
@dag.setter
- def dag(self, dag: Optional['DAG']):
+ def dag(self, dag: DAG | None):
"""
Operators can be assigned to one DAG, one time. Repeat assignments to
that same DAG are ok.
@@ -1051,7 +1029,7 @@ def dag(self, dag: Optional['DAG']):
self._dag = None
return
if not isinstance(dag, DAG):
- raise TypeError(f'Expected DAG; received {dag.__class__.__name__}')
+ raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
elif self.has_dag() and self.dag is not dag:
raise AirflowException(f"The DAG assigned to {self} can not be changed.")
@@ -1068,7 +1046,7 @@ def has_dag(self):
"""Returns True if the Operator has been assigned to a DAG."""
return self._dag is not None
- deps: FrozenSet[BaseTIDep] = frozenset(
+ deps: frozenset[BaseTIDep] = frozenset(
{
NotInRetryPeriodDep(),
PrevDagrunDep(),
@@ -1082,7 +1060,7 @@ def has_dag(self):
extended/overridden by subclasses.
"""
- def prepare_for_execution(self) -> "BaseOperator":
+ def prepare_for_execution(self) -> BaseOperator:
"""
Lock task for execution to disable custom action in __setattr__ and
returns a copy of the task
@@ -1146,18 +1124,17 @@ def post_execute(self, context: Any, result: Any = None):
def on_kill(self) -> None:
"""
- Override this method to cleanup subprocesses when a task instance
+ Override this method to clean up subprocesses when a task instance
gets killed. Any use of the threading, subprocess or multiprocessing
- module within an operator needs to be cleaned up or it will leave
+ module within an operator needs to be cleaned up, or it will leave
ghost processes behind.
"""
def __deepcopy__(self, memo):
- """
- Hack sorting double chained task lists by task_id to avoid hitting
- max_depth on deepcopy operations.
- """
+ # Hack sorting double chained task lists by task_id to avoid hitting
+ # max_depth on deepcopy operations.
sys.setrecursionlimit(5000) # TODO fix this in a better way
+
cls = self.__class__
result = cls.__new__(cls)
memo[id(self)] = result
@@ -1165,15 +1142,19 @@ def __deepcopy__(self, memo):
shallow_copy = cls.shallow_copy_attrs + cls._base_operator_shallow_copy_attrs
for k, v in self.__dict__.items():
+ if k == "_BaseOperator__instantiated":
+ # Don't set this until the _end_, as it changes behaviour of __setattr__
+ continue
if k not in shallow_copy:
setattr(result, k, copy.deepcopy(v, memo))
else:
setattr(result, k, copy.copy(v))
+ result.__instantiated = self.__instantiated
return result
def __getstate__(self):
state = dict(self.__dict__)
- del state['_log']
+ del state["_log"]
return state
@@ -1184,25 +1165,24 @@ def __setstate__(self, state):
def render_template_fields(
self,
context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- ) -> Optional["BaseOperator"]:
- """Template all attributes listed in template_fields.
+ jinja_env: jinja2.Environment | None = None,
+ ) -> None:
+ """Template all attributes listed in *self.template_fields*.
This mutates the attributes in-place and is irreversible.
- :param context: Dict with values to apply on content
- :param jinja_env: Jinja environment
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja's environment to use for rendering.
"""
if not jinja_env:
jinja_env = self.get_template_env()
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
- return self
@provide_session
def clear(
self,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
upstream: bool = False,
downstream: bool = False,
session: Session = NEW_SESSION,
@@ -1236,10 +1216,10 @@ def clear(
@provide_session
def get_task_instances(
self,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
session: Session = NEW_SESSION,
- ) -> List[TaskInstance]:
+ ) -> list[TaskInstance]:
"""Get task instances related to this task for a specific date range."""
from airflow.models import DagRun
@@ -1258,8 +1238,8 @@ def get_task_instances(
@provide_session
def run(
self,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
ignore_first_depends_on_past: bool = True,
ignore_ti_state: bool = False,
mark_success: bool = False,
@@ -1314,7 +1294,7 @@ def run(
def dry_run(self) -> None:
"""Performs dry run for the operator - just render template fields."""
- self.log.info('Dry run')
+ self.log.info("Dry run")
for field in self.template_fields:
try:
content = getattr(self, field)
@@ -1325,10 +1305,10 @@ def dry_run(self) -> None:
)
if content and isinstance(content, str):
- self.log.info('Rendering template for %s', field)
+ self.log.info("Rendering template for %s", field)
self.log.info(content)
- def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]:
+ def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
"""
Get list of the direct relatives to the current task, upstream or
downstream.
@@ -1342,7 +1322,7 @@ def __repr__(self):
return "<Task({self.task_type}): {self.task_id}>".format(self=self)
@property
- def operator_class(self) -> Type["BaseOperator"]: # type: ignore[override]
+ def operator_class(self) -> type[BaseOperator]: # type: ignore[override]
return self.__class__
@property
@@ -1351,17 +1331,25 @@ def task_type(self) -> str:
return self.__class__.__name__
@property
- def roots(self) -> List["BaseOperator"]:
+ def operator_name(self) -> str:
+ """@property: use a more friendly display name for the operator, if set"""
+ try:
+ return self.custom_operator_name # type: ignore
+ except AttributeError:
+ return self.task_type
+
+ @property
+ def roots(self) -> list[BaseOperator]:
"""Required by DAGNode."""
return [self]
@property
- def leaves(self) -> List["BaseOperator"]:
+ def leaves(self) -> list[BaseOperator]:
"""Required by DAGNode."""
return [self]
@property
- def output(self):
+ def output(self) -> XComArg:
"""Returns reference to XCom pushed by current operator"""
from airflow.models.xcom_arg import XComArg
@@ -1372,7 +1360,7 @@ def xcom_push(
context: Any,
key: str,
value: Any,
- execution_date: Optional[datetime] = None,
+ execution_date: datetime | None = None,
) -> None:
"""
Make an XCom available for tasks to pull.
@@ -1385,15 +1373,15 @@ def xcom_push(
this date. This can be used, for example, to send a message to a
task on a future date without it being immediately visible.
"""
- context['ti'].xcom_push(key=key, value=value, execution_date=execution_date)
+ context["ti"].xcom_push(key=key, value=value, execution_date=execution_date)
@staticmethod
def xcom_pull(
context: Any,
- task_ids: Optional[Union[str, List[str]]] = None,
- dag_id: Optional[str] = None,
+ task_ids: str | list[str] | None = None,
+ dag_id: str | None = None,
key: str = XCOM_RETURN_KEY,
- include_prior_dates: Optional[bool] = None,
+ include_prior_dates: bool | None = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
@@ -1421,7 +1409,7 @@ def xcom_pull(
execution_date are returned. If True, XComs from previous dates
are returned as well.
"""
- return context['ti'].xcom_pull(
+ return context["ti"].xcom_pull(
key=key, task_ids=task_ids, dag_id=dag_id, include_prior_dates=include_prior_dates
)
@@ -1437,61 +1425,54 @@ def get_serialized_fields(cls):
# Exception in SerializedDAG.serialize_dag() call.
DagContext.push_context_managed_dag(None)
cls.__serialized_fields = frozenset(
- vars(BaseOperator(task_id='test')).keys()
+ vars(BaseOperator(task_id="test")).keys()
- {
- 'inlets',
- 'outlets',
- 'upstream_task_ids',
- 'default_args',
- 'dag',
- '_dag',
- 'label',
- '_BaseOperator__instantiated',
- '_BaseOperator__init_kwargs',
- '_BaseOperator__from_mapped',
+ "upstream_task_ids",
+ "default_args",
+ "dag",
+ "_dag",
+ "label",
+ "_BaseOperator__instantiated",
+ "_BaseOperator__init_kwargs",
+ "_BaseOperator__from_mapped",
}
| { # Class level defaults need to be added to this list
- 'start_date',
- 'end_date',
- '_task_type',
- 'subdag',
- 'ui_color',
- 'ui_fgcolor',
- 'template_ext',
- 'template_fields',
- 'template_fields_renderers',
- 'params',
+ "start_date",
+ "end_date",
+ "_task_type",
+ "_operator_name",
+ "subdag",
+ "ui_color",
+ "ui_fgcolor",
+ "template_ext",
+ "template_fields",
+ "template_fields_renderers",
+ "params",
}
)
DagContext.pop_context_managed_dag()
return cls.__serialized_fields
- def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
+ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Required by DAGNode."""
return DagAttributeTypes.OP, self.task_id
- def is_smart_sensor_compatible(self):
- """Return if this operator can use smart service. Default False."""
- return False
-
- is_mapped: ClassVar[bool] = False
-
@property
def inherits_from_empty_operator(self):
"""Used to determine if an Operator is inherited from EmptyOperator"""
# This looks like `isinstance(self, EmptyOperator)` would work, but this also
# needs to cope when `self` is a Serialized instance of an EmptyOperator or one
- # of its sub-classes (which don't inherit from anything but BaseOperator).
- return getattr(self, '_is_empty', False)
+ # of its subclasses (which don't inherit from anything but BaseOperator).
+ return getattr(self, "_is_empty", False)
def defer(
self,
*,
trigger: BaseTrigger,
method_name: str,
- kwargs: Optional[Dict[str, Any]] = None,
- timeout: Optional[timedelta] = None,
+ kwargs: dict[str, Any] | None = None,
+ timeout: timedelta | None = None,
):
"""
Marks this Operator as being "deferred" - that is, suspending its
@@ -1502,13 +1483,7 @@ def defer(
"""
raise TaskDeferred(trigger=trigger, method_name=method_name, kwargs=kwargs, timeout=timeout)
- @classmethod
- def validate_mapped_arguments(cls, **kwargs: Any) -> None:
- """Validate arguments when this operator is being mapped."""
- if cls.mapped_arguments_validated_by_init:
- cls(**kwargs, _airflow_from_mapped=True)
-
- def unmap(self) -> "BaseOperator":
+ def unmap(self, resolve: None | dict[str, Any] | tuple[Context, Session]) -> BaseOperator:
""":meta private:"""
return self
@@ -1517,7 +1492,7 @@ def unmap(self) -> "BaseOperator":
Chainable = Union[DependencyMixin, Sequence[DependencyMixin]]
-def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None:
+def chain(*tasks: DependencyMixin | Sequence[DependencyMixin]) -> None:
r"""
Given a number of tasks, builds a dependency chain.
@@ -1588,12 +1563,12 @@ def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None:
.. code-block:: python
- chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, t2())
+ chain(t1, [Label("branch one"), Label("branch two")], [x1(), x2()], task_group1, x3())
is equivalent to::
/ "branch one" -> x1 \
- t1 -> t2 -> x3
+ t1 -> task_group1 -> x3
\ "branch two" -> x2 /
.. code-block:: python
@@ -1634,13 +1609,13 @@ def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None:
down_task.set_upstream(up_task)
continue
if not isinstance(up_task, Sequence) or not isinstance(down_task, Sequence):
- raise TypeError(f'Chain not supported between instances of {type(up_task)} and {type(down_task)}')
+ raise TypeError(f"Chain not supported between instances of {type(up_task)} and {type(down_task)}")
up_task_list = up_task
down_task_list = down_task
if len(up_task_list) != len(down_task_list):
raise AirflowException(
- f'Chain not supported for different length Iterable. '
- f'Got {len(up_task_list)} and {len(down_task_list)}.'
+ f"Chain not supported for different length Iterable. "
+ f"Got {len(up_task_list)} and {len(down_task_list)}."
)
for up_t, down_t in zip(up_task_list, down_task_list):
up_t.set_downstream(down_t)
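The pairing rules visible in the ``chain`` hunks above (a single task fans out to a list, a list fans in to a single task, and two lists of equal length are linked pairwise) can be sketched without Airflow installed. This is only an illustration under stated assumptions: ``Node`` and ``chain_sketch`` are hypothetical stand-ins, not Airflow's API, and the sketch raises ``ValueError`` where the real function raises ``AirflowException``.

```python
from __future__ import annotations

from collections.abc import Sequence


class Node:
    """Hypothetical stand-in for a task; it only records downstream names."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.downstream: set[str] = set()

    def set_downstream(self, other: Node | Sequence[Node]) -> None:
        # Accept either one node or a sequence of nodes.
        targets = other if isinstance(other, Sequence) else [other]
        for target in targets:
            self.downstream.add(target.name)


def chain_sketch(*tasks: Node | Sequence[Node]) -> None:
    """Link each argument to the next one, mirroring the branches in the diff."""
    for up, down in zip(tasks, tasks[1:]):
        if isinstance(up, Node):
            # One task followed by a task or a list: fan out.
            up.set_downstream(down)
        elif isinstance(down, Node):
            # A list followed by one task: fan in.
            for u in up:
                u.set_downstream(down)
        else:
            # Two lists: they must have equal length and are linked pairwise.
            if len(up) != len(down):
                raise ValueError(
                    f"Chain not supported for different length Iterable. "
                    f"Got {len(up)} and {len(down)}."
                )
            for u, d in zip(up, down):
                u.set_downstream(d)
```

With ``chain_sketch(t1, [a, b], [x1, x2], t2)``, ``t1`` points at both ``a`` and ``b``, ``a`` points only at ``x1``, ``b`` only at ``x2``, and both ``x1`` and ``x2`` point at ``t2``.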
@@ -1648,7 +1623,7 @@ def chain(*tasks: Union[DependencyMixin, Sequence[DependencyMixin]]) -> None:
def cross_downstream(
from_tasks: Sequence[DependencyMixin],
- to_tasks: Union[DependencyMixin, Sequence[DependencyMixin]],
+ to_tasks: DependencyMixin | Sequence[DependencyMixin],
):
r"""
Set downstream dependencies for all tasks in from_tasks to all tasks in to_tasks.
@@ -1747,11 +1722,19 @@ def cross_downstream(
task.set_downstream(to_tasks)
+# pyupgrade assumes all type annotations can be lazily evaluated, but this is
+# not the case for attrs-decorated classes, since cattrs needs to evaluate the
+# annotation expressions at runtime, and Python before 3.9.0 does not lazily
+# evaluate those. Putting the expression in a top-level assignment statement
+# communicates this runtime requirement to pyupgrade.
+BaseOperatorClassList = List[Type[BaseOperator]]
+
+
@attr.s(auto_attribs=True)
class BaseOperatorLink(metaclass=ABCMeta):
"""Abstract base class that defines how we get an operator link."""
- operators: ClassVar[List[Type[BaseOperator]]] = []
+ operators: ClassVar[BaseOperatorClassList] = []
"""
This property will be used by Airflow Plugins to find the Operators to which you want
to assign this Operator Link
@@ -1762,21 +1745,16 @@ class BaseOperatorLink(metaclass=ABCMeta):
@property
@abstractmethod
def name(self) -> str:
- """
- Name of the link. This will be the button name on the task UI.
-
- :return: link name
- """
+ """Name of the link. This will be the button name on the task UI."""
@abstractmethod
- def get_link(self, operator: AbstractOperator, *, ti_key: "TaskInstanceKey") -> str:
- """
- Link to external system.
+ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
+ """Link to external system.
Note: The old signature of this function was ``(self, operator, dttm: datetime)``. That is still
supported at runtime but is deprecated.
- :param operator: airflow operator
- :param ti_key: TaskInstance ID to return link for
+ :param operator: The Airflow operator object this link is associated to.
+ :param ti_key: TaskInstance ID to return link for.
:return: link to external system
"""
diff --git a/airflow/models/connection.py b/airflow/models/connection.py
index b21e68b73b356..5f7406d9d4835 100644
--- a/airflow/models/connection.py
+++ b/airflow/models/connection.py
@@ -15,23 +15,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
import logging
import warnings
from json import JSONDecodeError
-from typing import Dict, Optional, Union
-from urllib.parse import parse_qsl, quote, unquote, urlencode, urlparse
+from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
from sqlalchemy import Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import reconstructor, synonym
from airflow.configuration import ensure_secrets_loaded
-from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.exceptions import AirflowException, AirflowNotFoundException, RemovedInAirflow3Warning
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
-from airflow.providers_manager import ProvidersManager
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.log.secrets_masker import mask_secret
from airflow.utils.module_loading import import_string
@@ -41,7 +40,7 @@
def parse_netloc_to_hostname(*args, **kwargs):
"""This method is deprecated."""
- warnings.warn("This method is deprecated.", DeprecationWarning)
+ warnings.warn("This method is deprecated.", RemovedInAirflow3Warning)
return _parse_netloc_to_hostname(*args, **kwargs)
@@ -49,8 +48,8 @@ def parse_netloc_to_hostname(*args, **kwargs):
# See: https://issues.apache.org/jira/browse/AIRFLOW-3615
def _parse_netloc_to_hostname(uri_parts):
"""Parse a URI string to get correct Hostname."""
- hostname = unquote(uri_parts.hostname or '')
- if '/' in hostname:
+ hostname = unquote(uri_parts.hostname or "")
+ if "/" in hostname:
hostname = uri_parts.netloc
if "@" in hostname:
hostname = hostname.rsplit("@", 1)[1]
@@ -83,35 +82,35 @@ class Connection(Base, LoggingMixin):
:param uri: URI address describing connection parameters.
"""
- EXTRA_KEY = '__extra__'
+ EXTRA_KEY = "__extra__"
__tablename__ = "connection"
id = Column(Integer(), primary_key=True)
conn_id = Column(String(ID_LEN), unique=True, nullable=False)
conn_type = Column(String(500), nullable=False)
- description = Column(Text(5000))
+ description = Column(Text().with_variant(Text(5000), "mysql").with_variant(String(5000), "sqlite"))
host = Column(String(500))
schema = Column(String(500))
login = Column(String(500))
- _password = Column('password', String(5000))
+ _password = Column("password", String(5000))
port = Column(Integer())
is_encrypted = Column(Boolean, unique=False, default=False)
is_extra_encrypted = Column(Boolean, unique=False, default=False)
- _extra = Column('extra', Text())
+ _extra = Column("extra", Text())
def __init__(
self,
- conn_id: Optional[str] = None,
- conn_type: Optional[str] = None,
- description: Optional[str] = None,
- host: Optional[str] = None,
- login: Optional[str] = None,
- password: Optional[str] = None,
- schema: Optional[str] = None,
- port: Optional[int] = None,
- extra: Optional[Union[str, dict]] = None,
- uri: Optional[str] = None,
+ conn_id: str | None = None,
+ conn_type: str | None = None,
+ description: str | None = None,
+ host: str | None = None,
+ login: str | None = None,
+ password: str | None = None,
+ schema: str | None = None,
+ port: int | None = None,
+ extra: str | dict | None = None,
+ uri: str | None = None,
):
super().__init__()
self.conn_id = conn_id
@@ -155,14 +154,14 @@ def _validate_extra(extra, conn_id) -> None:
"Encountered JSON value in `extra` which does not parse as a dictionary in "
f"connection {conn_id!r}. From Airflow 3.0, the `extra` field must contain a JSON "
"representation of a Python dict.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=3,
)
except json.JSONDecodeError:
warnings.warn(
f"Encountered non-JSON in `extra` field for connection {conn_id!r}. Support for "
"non-JSON `extra` will be removed in Airflow 3.0",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return None
@@ -175,20 +174,21 @@ def on_db_load(self):
def parse_from_uri(self, **uri):
"""This method is deprecated. Please use uri parameter in constructor."""
warnings.warn(
- "This method is deprecated. Please use uri parameter in constructor.", DeprecationWarning
+ "This method is deprecated. Please use uri parameter in constructor.",
+ RemovedInAirflow3Warning,
)
self._parse_from_uri(**uri)
@staticmethod
def _normalize_conn_type(conn_type):
- if conn_type == 'postgresql':
- conn_type = 'postgres'
- elif '-' in conn_type:
- conn_type = conn_type.replace('-', '_')
+ if conn_type == "postgresql":
+ conn_type = "postgres"
+ elif "-" in conn_type:
+ conn_type = conn_type.replace("-", "_")
return conn_type
def _parse_from_uri(self, uri: str):
- uri_parts = urlparse(uri)
+ uri_parts = urlsplit(uri)
conn_type = uri_parts.scheme
self.conn_type = self._normalize_conn_type(conn_type)
self.host = _parse_netloc_to_hostname(uri_parts)
@@ -206,35 +206,38 @@ def _parse_from_uri(self, uri: str):
def get_uri(self) -> str:
"""Return connection in URI format"""
- if '_' in self.conn_type:
+ if self.conn_type and "_" in self.conn_type:
self.log.warning(
"Connection schemes (type: %s) shall not contain '_' according to RFC3986.",
self.conn_type,
)
- uri = f"{str(self.conn_type).lower().replace('_', '-')}://"
+ if self.conn_type:
+ uri = f"{self.conn_type.lower().replace('_', '-')}://"
+ else:
+ uri = "//"
- authority_block = ''
+ authority_block = ""
if self.login is not None:
- authority_block += quote(self.login, safe='')
+ authority_block += quote(self.login, safe="")
if self.password is not None:
- authority_block += ':' + quote(self.password, safe='')
+ authority_block += ":" + quote(self.password, safe="")
- if authority_block > '':
- authority_block += '@'
+ if authority_block > "":
+ authority_block += "@"
uri += authority_block
- host_block = ''
+ host_block = ""
if self.host:
- host_block += quote(self.host, safe='')
+ host_block += quote(self.host, safe="")
if self.port:
- if host_block > '':
- host_block += f':{self.port}'
+ if host_block == "" and authority_block == "":
+ host_block += f"@:{self.port}"
else:
- host_block += f'@:{self.port}'
+ host_block += f":{self.port}"
if self.schema:
host_block += f"/{quote(self.schema, safe='')}"
@@ -243,17 +246,17 @@ def get_uri(self) -> str:
if self.extra:
try:
- query: Optional[str] = urlencode(self.extra_dejson)
+ query: str | None = urlencode(self.extra_dejson)
except TypeError:
query = None
if query and self.extra_dejson == dict(parse_qsl(query, keep_blank_values=True)):
- uri += '?' + query
+ uri += ("?" if self.schema else "/?") + query
else:
- uri += '?' + urlencode({self.EXTRA_KEY: self.extra})
+ uri += ("?" if self.schema else "/?") + urlencode({self.EXTRA_KEY: self.extra})
return uri
- def get_password(self) -> Optional[str]:
+ def get_password(self) -> str | None:
"""Return encrypted password."""
if self._password and self.is_encrypted:
fernet = get_fernet()
@@ -262,23 +265,23 @@ def get_password(self) -> Optional[str]:
f"Can't decrypt encrypted password for login={self.login} "
f"FERNET_KEY configuration is missing"
)
- return fernet.decrypt(bytes(self._password, 'utf-8')).decode()
+ return fernet.decrypt(bytes(self._password, "utf-8")).decode()
else:
return self._password
- def set_password(self, value: Optional[str]):
+ def set_password(self, value: str | None):
"""Encrypt password and set in object attribute."""
if value:
fernet = get_fernet()
- self._password = fernet.encrypt(bytes(value, 'utf-8')).decode()
+ self._password = fernet.encrypt(bytes(value, "utf-8")).decode()
self.is_encrypted = fernet.is_encrypted
@declared_attr
def password(cls):
"""Password. The value is decrypted/encrypted when reading/setting the value."""
- return synonym('_password', descriptor=property(cls.get_password, cls.set_password))
+ return synonym("_password", descriptor=property(cls.get_password, cls.set_password))
- def get_extra(self) -> Dict:
+ def get_extra(self) -> str:
"""Return encrypted extra-data."""
if self._extra and self.is_extra_encrypted:
fernet = get_fernet()
@@ -287,7 +290,7 @@ def get_extra(self) -> Dict:
f"Can't decrypt `extra` params for login={self.login}, "
f"FERNET_KEY configuration is missing"
)
- extra_val = fernet.decrypt(bytes(self._extra, 'utf-8')).decode()
+ extra_val = fernet.decrypt(bytes(self._extra, "utf-8")).decode()
else:
extra_val = self._extra
if extra_val:
@@ -299,7 +302,7 @@ def set_extra(self, value: str):
if value:
self._validate_extra(value, self.conn_id)
fernet = get_fernet()
- self._extra = fernet.encrypt(bytes(value, 'utf-8')).decode()
+ self._extra = fernet.encrypt(bytes(value, "utf-8")).decode()
self.is_extra_encrypted = fernet.is_encrypted
else:
self._extra = value
@@ -308,18 +311,20 @@ def set_extra(self, value: str):
@declared_attr
def extra(cls):
"""Extra data. The value is decrypted/encrypted when reading/setting the value."""
- return synonym('_extra', descriptor=property(cls.get_extra, cls.set_extra))
+ return synonym("_extra", descriptor=property(cls.get_extra, cls.set_extra))
def rotate_fernet_key(self):
"""Encrypts data with a new key. See: :ref:`security/fernet`"""
fernet = get_fernet()
if self._password and self.is_encrypted:
- self._password = fernet.rotate(self._password.encode('utf-8')).decode()
+ self._password = fernet.rotate(self._password.encode("utf-8")).decode()
if self._extra and self.is_extra_encrypted:
- self._extra = fernet.rotate(self._extra.encode('utf-8')).decode()
+ self._extra = fernet.rotate(self._extra.encode("utf-8")).decode()
def get_hook(self, *, hook_params=None):
"""Return hook based on conn_type"""
+ from airflow.providers_manager import ProvidersManager
+
hook = ProvidersManager().hooks.get(self.conn_type, None)
if hook is None:
@@ -339,7 +344,7 @@ def get_hook(self, *, hook_params=None):
return hook_class(**{hook.connection_id_attribute_name: self.conn_id}, **hook_params)
def __repr__(self):
- return self.conn_id or ''
+ return self.conn_id or ""
def log_info(self):
"""
@@ -349,7 +354,7 @@ def log_info(self):
warnings.warn(
"This method is deprecated. You can read each field individually or "
"use the default representation (__repr__).",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return (
@@ -366,7 +371,7 @@ def debug_info(self):
warnings.warn(
"This method is deprecated. You can read each field individually or "
"use the default representation (__repr__).",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return (
@@ -377,10 +382,10 @@ def debug_info(self):
def test_connection(self):
"""Calls out get_hook method and executes test_connection method on that."""
- status, message = False, ''
+ status, message = False, ""
try:
hook = self.get_hook()
- if getattr(hook, 'test_connection', False):
+ if getattr(hook, "test_connection", False):
status, message = hook.test_connection()
else:
message = (
@@ -392,7 +397,7 @@ def test_connection(self):
return status, message
@property
- def extra_dejson(self) -> Dict:
+ def extra_dejson(self) -> dict:
"""Returns the extra property by deserializing json."""
obj = {}
if self.extra:
@@ -408,7 +413,7 @@ def extra_dejson(self) -> Dict:
return obj
@classmethod
- def get_connection_from_secrets(cls, conn_id: str) -> 'Connection':
+ def get_connection_from_secrets(cls, conn_id: str) -> Connection:
"""
Get connection by conn_id.
@@ -422,26 +427,26 @@ def get_connection_from_secrets(cls, conn_id: str) -> 'Connection':
return conn
except Exception:
log.exception(
- 'Unable to retrieve connection from secrets backend (%s). '
- 'Checking subsequent secrets backend.',
+ "Unable to retrieve connection from secrets backend (%s). "
+ "Checking subsequent secrets backend.",
type(secrets_backend).__name__,
)
raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined")
@classmethod
- def from_json(cls, value, conn_id=None) -> 'Connection':
+ def from_json(cls, value, conn_id=None) -> Connection:
kwargs = json.loads(value)
- extra = kwargs.pop('extra', None)
+ extra = kwargs.pop("extra", None)
if extra:
- kwargs['extra'] = extra if isinstance(extra, str) else json.dumps(extra)
- conn_type = kwargs.pop('conn_type', None)
+ kwargs["extra"] = extra if isinstance(extra, str) else json.dumps(extra)
+ conn_type = kwargs.pop("conn_type", None)
if conn_type:
- kwargs['conn_type'] = cls._normalize_conn_type(conn_type)
- port = kwargs.pop('port', None)
+ kwargs["conn_type"] = cls._normalize_conn_type(conn_type)
+ port = kwargs.pop("port", None)
if port:
try:
- kwargs['port'] = int(port)
+ kwargs["port"] = int(port)
except ValueError:
raise ValueError(f"Expected integer value for `port`, but got {port!r} instead.")
return Connection(conn_id=conn_id, **kwargs)
diff --git a/airflow/models/crypto.py b/airflow/models/crypto.py
index b57c53732df32..d1a2b06a60599 100644
--- a/airflow/models/crypto.py
+++ b/airflow/models/crypto.py
@@ -15,10 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import logging
-from typing import Optional
+from __future__ import annotations
-from cryptography.fernet import Fernet, MultiFernet
+import logging
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -58,7 +57,7 @@ def encrypt(self, b):
return b
-_fernet = None # type: Optional[FernetProtocol]
+_fernet: FernetProtocol | None = None
def get_fernet():
@@ -71,19 +70,21 @@ def get_fernet():
:return: Fernet object
:raises: airflow.exceptions.AirflowException if there's a problem trying to load Fernet
"""
+ from cryptography.fernet import Fernet, MultiFernet
+
global _fernet
if _fernet:
return _fernet
try:
- fernet_key = conf.get('core', 'FERNET_KEY')
+ fernet_key = conf.get("core", "FERNET_KEY")
if not fernet_key:
log.warning("empty cryptography key - values will not be stored encrypted.")
_fernet = NullFernet()
else:
_fernet = MultiFernet(
- [Fernet(fernet_part.encode('utf-8')) for fernet_part in fernet_key.split(',')]
+ [Fernet(fernet_part.encode("utf-8")) for fernet_part in fernet_key.split(",")]
)
_fernet.is_encrypted = True
except (ValueError, TypeError) as value_error:
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index ece03bafe8c92..4cb59754468d2 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+import collections
import copy
import functools
import itertools
@@ -28,6 +30,7 @@
import traceback
import warnings
import weakref
+from collections import deque
from datetime import datetime, timedelta
from inspect import signature
from typing import (
@@ -35,25 +38,23 @@
Any,
Callable,
Collection,
- Dict,
- FrozenSet,
+ Deque,
Iterable,
+ Iterator,
List,
- Optional,
Sequence,
- Set,
- Tuple,
- Type,
Union,
cast,
overload,
)
+from urllib.parse import urlsplit
import jinja2
import pendulum
from dateutil.relativedelta import relativedelta
from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, and_, case, func, not_, or_
+from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import backref, joinedload, relationship
from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
@@ -62,28 +63,36 @@
import airflow.templates
from airflow import settings, utils
from airflow.compat.functools import cached_property
-from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DuplicateTaskIdFound, TaskNotFound
+from airflow.configuration import conf, secrets_backend_list
+from airflow.exceptions import (
+ AirflowDagInconsistent,
+ AirflowException,
+ AirflowSkipException,
+ DuplicateTaskIdFound,
+ RemovedInAirflow3Warning,
+ TaskNotFound,
+)
from airflow.models.abstractoperator import AbstractOperator
-from airflow.models.base import ID_LEN, Base
-from airflow.models.dagbag import DagBag
+from airflow.models.base import Base, StringID
from airflow.models.dagcode import DagCode
from airflow.models.dagpickle import DagPickle
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.models.param import DagParam, ParamsDict
from airflow.models.taskinstance import Context, TaskInstance, TaskInstanceKey, clear_task_instances
+from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.security import permissions
from airflow.stats import Stats
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
-from airflow.timetables.simple import NullTimetable, OnceTimetable
+from airflow.timetables.simple import DatasetTriggeredTimetable, NullTimetable, OnceTimetable
from airflow.typing_compat import Literal
from airflow.utils import timezone
from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.dates import cron_presets, date_range as utils_date_range
+from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.file import correct_maybe_zipped
-from airflow.utils.helpers import exactly_one, validate_key
+from airflow.utils.helpers import at_most_one, exactly_one, validate_key
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
@@ -91,16 +100,21 @@
from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
if TYPE_CHECKING:
+ from types import ModuleType
+
+ from airflow.datasets import Dataset
from airflow.decorators import TaskDecoratorCollection
+ from airflow.models.dagbag import DagBag
from airflow.models.slamiss import SlaMiss
from airflow.utils.task_group import TaskGroup
log = logging.getLogger(__name__)
-DEFAULT_VIEW_PRESETS = ['grid', 'graph', 'duration', 'gantt', 'landing_times']
-ORIENTATION_PRESETS = ['LR', 'TB', 'RL', 'BT']
+DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"]
+ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"]
+TAG_MAX_LEN = 100
DagStateChangeCallback = Callable[[Context], None]
ScheduleInterval = Union[None, str, timedelta, relativedelta]
@@ -109,6 +123,9 @@
# but Mypy cannot handle that right now. Track progress of PEP 661 for progress.
# See also: https://discuss.python.org/t/9126/7
ScheduleIntervalArg = Union[ArgNotSet, ScheduleInterval]
+ScheduleArg = Union[ArgNotSet, ScheduleInterval, Timetable, Collection["Dataset"]]
+
+SLAMissCallback = Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]
# Backward compatibility: If neither schedule_interval nor timetable is
@@ -142,7 +159,7 @@ def _get_model_data_interval(
instance: Any,
start_field_name: str,
end_field_name: str,
-) -> Optional[DataInterval]:
+) -> DataInterval | None:
start = timezone.coerce_datetime(getattr(instance, start_field_name))
end = timezone.coerce_datetime(getattr(instance, end_field_name))
if start is None:
@@ -172,7 +189,7 @@ def create_timetable(interval: ScheduleIntervalArg, timezone: Timezone) -> Timet
def get_last_dagrun(dag_id, session, include_externally_triggered=False):
"""
Returns the last dag run for a dag, None if there was none.
- Last dag run can be any type of run eg. scheduled or backfilled.
+ Last dag run can be any type of run e.g. scheduled or backfilled.
Overridden DagRuns are ignored.
"""
DR = DagRun
@@ -183,6 +200,49 @@ def get_last_dagrun(dag_id, session, include_externally_triggered=False):
return query.first()
+def get_dataset_triggered_next_run_info(
+ dag_ids: list[str], *, session: Session
+) -> dict[str, dict[str, int | str]]:
+ """
+ Given a list of dag_ids, get a string representing how close any that are dataset triggered are
+ to their next run, e.g. "1 of 2 datasets updated"
+ """
+ from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ, DatasetModel
+
+ return {
+ x.dag_id: {
+ "uri": x.uri,
+ "ready": x.ready,
+ "total": x.total,
+ }
+ for x in session.query(
+ DagScheduleDatasetReference.dag_id,
+ # This is a dirty hack to workaround group by requiring an aggregate, since grouping by dataset
+ # is not what we want to do here...but it works
+ case((func.count() == 1, func.max(DatasetModel.uri)), else_="").label("uri"),
+ func.count().label("total"),
+ func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)).label("ready"),
+ )
+ .join(
+ DDRQ,
+ and_(
+ DDRQ.dataset_id == DagScheduleDatasetReference.dataset_id,
+ DDRQ.target_dag_id == DagScheduleDatasetReference.dag_id,
+ ),
+ isouter=True,
+ )
+ .join(
+ DatasetModel,
+ DatasetModel.id == DagScheduleDatasetReference.dataset_id,
+ )
+ .group_by(
+ DagScheduleDatasetReference.dag_id,
+ )
+ .filter(DagScheduleDatasetReference.dag_id.in_(dag_ids))
+ .all()
+ }
+
+
@functools.total_ordering
class DAG(LoggingMixin):
"""
@@ -199,19 +259,25 @@ class DAG(LoggingMixin):
Note that if you plan to use time zones all the dates provided should be pendulum
dates. See :ref:`timezone_aware_dags`.
+ .. versionadded:: 2.4
+ The *schedule* argument to specify either time-based scheduling logic
+ (timetable), or dataset-driven triggers.
+
+ .. deprecated:: 2.4
+ The arguments *schedule_interval* and *timetable*. Their functionalities
+ are merged into the new *schedule* argument.
+
:param dag_id: The id of the DAG; must consist exclusively of alphanumeric
characters, dashes, dots and underscores (all ASCII)
:param description: The description for the DAG to e.g. be shown on the webserver
- :param schedule_interval: Defines how often that DAG runs, this
- timedelta object gets added to your latest task instance's
- execution_date to figure out the next schedule
- :param timetable: Specify which timetable to use (in which case schedule_interval
- must not be set). See :doc:`/howto/timetable` for more information
+ :param schedule: Defines the rules according to which DAG runs are scheduled. Can
+ accept cron string, timedelta object, Timetable, or list of Dataset objects.
+ See also :doc:`/howto/timetable`.
:param start_date: The timestamp from which the scheduler will
attempt to backfill
:param end_date: A date beyond which your DAG won't run, leave to None
- for open ended scheduling
- :param template_searchpath: This list of folders (non relative)
+ for open-ended scheduling
+ :param template_searchpath: This list of folders (non-relative)
defines where jinja will look for your templates. Order matters.
Note that jinja/airflow includes the path of your DAG file by
default
@@ -279,21 +345,25 @@ class DAG(LoggingMixin):
to render templates as native Python types. If False, a Jinja
``Environment`` is used to render templates as string values.
:param tags: List of tags to help filtering DAGs in the UI.
+ :param owner_links: Dict of owners and their links, that will be clickable on the DAGs view UI.
+ Can be used as an HTTP link (for example the link to your Slack channel), or a mailto link.
+ e.g: {"dag_owner": "https://airflow.apache.org/"}
+ :param auto_register: Automatically register this DAG when it is used in a ``with`` block
"""
_comps = {
- 'dag_id',
- 'task_ids',
- 'parent_dag',
- 'start_date',
- 'end_date',
- 'schedule_interval',
- 'fileloc',
- 'template_searchpath',
- 'last_loaded',
+ "dag_id",
+ "task_ids",
+ "parent_dag",
+ "start_date",
+ "end_date",
+ "schedule_interval",
+ "fileloc",
+ "template_searchpath",
+ "last_loaded",
}
- __serialized_fields: Optional[FrozenSet[str]] = None
+ __serialized_fields: frozenset[str] | None = None
fileloc: str
"""
@@ -303,44 +373,51 @@ class DAG(LoggingMixin):
from a ZIP file or other DAG distribution format.
"""
- parent_dag: Optional["DAG"] = None # Gets set when DAGs are loaded
+ parent_dag: DAG | None = None # Gets set when DAGs are loaded
+ # NOTE: When updating arguments here, please also keep arguments in @dag()
+ # below in sync. (Search for 'def dag(' in this file.)
def __init__(
self,
dag_id: str,
- description: Optional[str] = None,
+ description: str | None = None,
+ schedule: ScheduleArg = NOTSET,
schedule_interval: ScheduleIntervalArg = NOTSET,
- timetable: Optional[Timetable] = None,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- full_filepath: Optional[str] = None,
- template_searchpath: Optional[Union[str, Iterable[str]]] = None,
- template_undefined: Type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
- user_defined_macros: Optional[Dict] = None,
- user_defined_filters: Optional[Dict] = None,
- default_args: Optional[Dict] = None,
- concurrency: Optional[int] = None,
- max_active_tasks: int = conf.getint('core', 'max_active_tasks_per_dag'),
- max_active_runs: int = conf.getint('core', 'max_active_runs_per_dag'),
- dagrun_timeout: Optional[timedelta] = None,
- sla_miss_callback: Optional[
- Callable[["DAG", str, str, List["SlaMiss"], List[TaskInstance]], None]
- ] = None,
- default_view: str = conf.get_mandatory_value('webserver', 'dag_default_view').lower(),
- orientation: str = conf.get_mandatory_value('webserver', 'dag_orientation'),
- catchup: bool = conf.getboolean('scheduler', 'catchup_by_default'),
- on_success_callback: Optional[DagStateChangeCallback] = None,
- on_failure_callback: Optional[DagStateChangeCallback] = None,
- doc_md: Optional[str] = None,
- params: Optional[Dict] = None,
- access_control: Optional[Dict] = None,
- is_paused_upon_creation: Optional[bool] = None,
- jinja_environment_kwargs: Optional[Dict] = None,
+ timetable: Timetable | None = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ full_filepath: str | None = None,
+ template_searchpath: str | Iterable[str] | None = None,
+ template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
+ user_defined_macros: dict | None = None,
+ user_defined_filters: dict | None = None,
+ default_args: dict | None = None,
+ concurrency: int | None = None,
+ max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
+ max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
+ dagrun_timeout: timedelta | None = None,
+ sla_miss_callback: SLAMissCallback | None = None,
+ default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
+ orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
+ catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
+ on_success_callback: DagStateChangeCallback | None = None,
+ on_failure_callback: DagStateChangeCallback | None = None,
+ doc_md: str | None = None,
+ params: dict | None = None,
+ access_control: dict | None = None,
+ is_paused_upon_creation: bool | None = None,
+ jinja_environment_kwargs: dict | None = None,
render_template_as_native_obj: bool = False,
- tags: Optional[List[str]] = None,
+ tags: list[str] | None = None,
+ owner_links: dict[str, str] | None = None,
+ auto_register: bool = True,
):
from airflow.utils.task_group import TaskGroup
+ if tags and any(len(tag) > TAG_MAX_LEN for tag in tags):
+ raise AirflowException(f"tag cannot be longer than {TAG_MAX_LEN} characters")
+
+ self.owner_links = owner_links if owner_links else {}
self.user_defined_macros = user_defined_macros
self.user_defined_filters = user_defined_filters
if default_args and not isinstance(default_args, dict):
@@ -349,9 +426,9 @@ def __init__(
params = params or {}
# merging potentially conflicting default_args['params'] into params
- if 'params' in self.default_args:
- params.update(self.default_args['params'])
- del self.default_args['params']
+ if "params" in self.default_args:
+ params.update(self.default_args["params"])
+ del self.default_args["params"]
# check self.params and convert them into ParamsDict
self.params = ParamsDict(params)
@@ -359,7 +436,7 @@ def __init__(
if full_filepath:
warnings.warn(
"Passing full_filepath to DAG() is deprecated and has no effect",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
@@ -370,29 +447,29 @@ def __init__(
# TODO: Remove in Airflow 3.0
warnings.warn(
"The 'concurrency' parameter is deprecated. Please use 'max_active_tasks'.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
max_active_tasks = concurrency
self._max_active_tasks = max_active_tasks
- self._pickle_id: Optional[int] = None
+ self._pickle_id: int | None = None
self._description = description
# set file location to caller source path
back = sys._getframe().f_back
self.fileloc = back.f_code.co_filename if back else ""
- self.task_dict: Dict[str, Operator] = {}
+ self.task_dict: dict[str, Operator] = {}
# set timezone from start_date
tz = None
if start_date and start_date.tzinfo:
tzinfo = None if start_date.tzinfo else settings.TIMEZONE
tz = pendulum.instance(start_date, tz=tzinfo).timezone
- elif 'start_date' in self.default_args and self.default_args['start_date']:
- date = self.default_args['start_date']
+ elif "start_date" in self.default_args and self.default_args["start_date"]:
+ date = self.default_args["start_date"]
if not isinstance(date, datetime):
date = timezone.parse(date)
- self.default_args['start_date'] = date
+ self.default_args["start_date"] = date
start_date = date
tzinfo = None if date.tzinfo else settings.TIMEZONE
@@ -400,62 +477,96 @@ def __init__(
self.timezone = tz or settings.TIMEZONE
# Apply the timezone we settled on to end_date if it wasn't supplied
- if 'end_date' in self.default_args and self.default_args['end_date']:
- if isinstance(self.default_args['end_date'], str):
- self.default_args['end_date'] = timezone.parse(
- self.default_args['end_date'], timezone=self.timezone
+ if "end_date" in self.default_args and self.default_args["end_date"]:
+ if isinstance(self.default_args["end_date"], str):
+ self.default_args["end_date"] = timezone.parse(
+ self.default_args["end_date"], timezone=self.timezone
)
self.start_date = timezone.convert_to_utc(start_date)
self.end_date = timezone.convert_to_utc(end_date)
# also convert tasks
- if 'start_date' in self.default_args:
- self.default_args['start_date'] = timezone.convert_to_utc(self.default_args['start_date'])
- if 'end_date' in self.default_args:
- self.default_args['end_date'] = timezone.convert_to_utc(self.default_args['end_date'])
+ if "start_date" in self.default_args:
+ self.default_args["start_date"] = timezone.convert_to_utc(self.default_args["start_date"])
+ if "end_date" in self.default_args:
+ self.default_args["end_date"] = timezone.convert_to_utc(self.default_args["end_date"])
+
+ # sort out DAG's scheduling behavior
+ scheduling_args = [schedule_interval, timetable, schedule]
+ if not at_most_one(*scheduling_args):
+ raise ValueError("At most one allowed for args 'schedule_interval', 'timetable', and 'schedule'.")
+ if schedule_interval is not NOTSET:
+ warnings.warn(
+ "Param `schedule_interval` is deprecated and will be removed in a future release. "
+ "Please use `schedule` instead. ",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
+ if timetable is not None:
+ warnings.warn(
+ "Param `timetable` is deprecated and will be removed in a future release. "
+ "Please use `schedule` instead. ",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
- # Calculate the DAG's timetable.
- if timetable is None:
- self.timetable = create_timetable(schedule_interval, self.timezone)
- if isinstance(schedule_interval, ArgNotSet):
- schedule_interval = DEFAULT_SCHEDULE_INTERVAL
- self.schedule_interval: ScheduleInterval = schedule_interval
- elif isinstance(schedule_interval, ArgNotSet):
+ self.timetable: Timetable
+ self.schedule_interval: ScheduleInterval
+ self.dataset_triggers: Collection[Dataset] = []
+
+ if isinstance(schedule, Collection) and not isinstance(schedule, str):
+ from airflow.datasets import Dataset
+
+ if not all(isinstance(x, Dataset) for x in schedule):
+ raise ValueError("All elements in 'schedule' should be datasets")
+ self.dataset_triggers = list(schedule)
+ elif isinstance(schedule, Timetable):
+ timetable = schedule
+ elif schedule is not NOTSET:
+ schedule_interval = schedule
+
+ if self.dataset_triggers:
+ self.timetable = DatasetTriggeredTimetable()
+ self.schedule_interval = self.timetable.summary
+ elif timetable:
self.timetable = timetable
self.schedule_interval = self.timetable.summary
else:
- raise TypeError("cannot specify both 'schedule_interval' and 'timetable'")
+ if isinstance(schedule_interval, ArgNotSet):
+ schedule_interval = DEFAULT_SCHEDULE_INTERVAL
+ self.schedule_interval = schedule_interval
+ self.timetable = create_timetable(schedule_interval, self.timezone)
if isinstance(template_searchpath, str):
template_searchpath = [template_searchpath]
self.template_searchpath = template_searchpath
self.template_undefined = template_undefined
self.last_loaded = timezone.utcnow()
- self.safe_dag_id = dag_id.replace('.', '__dot__')
+ self.safe_dag_id = dag_id.replace(".", "__dot__")
self.max_active_runs = max_active_runs
self.dagrun_timeout = dagrun_timeout
self.sla_miss_callback = sla_miss_callback
if default_view in DEFAULT_VIEW_PRESETS:
self._default_view: str = default_view
- elif default_view == 'tree':
+ elif default_view == "tree":
warnings.warn(
"`default_view` of 'tree' has been renamed to 'grid' -- please update your DAG",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
- self._default_view = 'grid'
+ self._default_view = "grid"
else:
raise AirflowException(
- f'Invalid values of dag.default_view: only support '
- f'{DEFAULT_VIEW_PRESETS}, but get {default_view}'
+ f"Invalid values of dag.default_view: only support "
+ f"{DEFAULT_VIEW_PRESETS}, but get {default_view}"
)
if orientation in ORIENTATION_PRESETS:
self.orientation = orientation
else:
raise AirflowException(
- f'Invalid values of dag.orientation: only support '
- f'{ORIENTATION_PRESETS}, but get {orientation}'
+ f"Invalid values of dag.orientation: only support "
+ f"{ORIENTATION_PRESETS}, but get {orientation}"
)
self.catchup = catchup
@@ -466,23 +577,96 @@ def __init__(
# Keeps track of any extra edge metadata (sparse; will not contain all
# edges, so do not iterate over it for that). Outer key is upstream
# task ID, inner key is downstream task ID.
- self.edge_info: Dict[str, Dict[str, EdgeInfoType]] = {}
+ self.edge_info: dict[str, dict[str, EdgeInfoType]] = {}
# To keep it in parity with Serialized DAGs
# and identify if DAG has on_*_callback without actually storing them in Serialized JSON
self.has_on_success_callback = self.on_success_callback is not None
self.has_on_failure_callback = self.on_failure_callback is not None
- self.doc_md = doc_md
-
self._access_control = DAG._upgrade_outdated_dag_access_control(access_control)
self.is_paused_upon_creation = is_paused_upon_creation
+ self.auto_register = auto_register
self.jinja_environment_kwargs = jinja_environment_kwargs
self.render_template_as_native_obj = render_template_as_native_obj
+
+ self.doc_md = self.get_doc_md(doc_md)
+
self.tags = tags or []
self._task_group = TaskGroup.create_root(self)
self.validate_schedule_and_params()
+ wrong_links = dict(self.iter_invalid_owner_links())
+ if wrong_links:
+ raise AirflowException(
+ "Wrong link format was used for the owner. Use a valid link \n"
+ f"Bad formatted links are: {wrong_links}"
+ )
+
+ # this will only be set at serialization time
+ # its only use is for determining the relative
+ # fileloc based only on the serialized dag
+ self._processor_dags_folder = None
+
+ def get_doc_md(self, doc_md: str | None) -> str | None:
+ if doc_md is None:
+ return doc_md
+
+ env = self.get_template_env(force_sandboxed=True)
+
+ if not doc_md.endswith(".md"):
+ template = jinja2.Template(doc_md)
+ else:
+ try:
+ template = env.get_template(doc_md)
+ except jinja2.exceptions.TemplateNotFound:
+ return f"""
+ # Templating Error!
+ Not able to find the template file: `{doc_md}`.
+ """
+
+ return template.render()
+
+ def _check_schedule_interval_matches_timetable(self) -> bool:
+ """Check ``schedule_interval`` and ``timetable`` match.
+
+ This is done as a part of the DAG validation done before it's bagged, to
+ guard against the DAG's ``timetable`` (or ``schedule_interval``) being
+ changed after it's created, e.g.
+
+ .. code-block:: python
+
+ dag1 = DAG("d1", timetable=MyTimetable())
+ dag1.schedule_interval = "@once"
+
+ dag2 = DAG("d2", schedule="@once")
+ dag2.timetable = MyTimetable()
+
+ Validation is done by creating a timetable and checking that its summary matches
+ ``schedule_interval``. The logic is not bullet-proof, especially if a
+ custom timetable does not provide a useful ``summary``. But this is the
+ best we can do.
+ """
+ if self.schedule_interval == self.timetable.summary:
+ return True
+ try:
+ timetable = create_timetable(self.schedule_interval, self.timezone)
+ except ValueError:
+ return False
+ return timetable.summary == self.timetable.summary
+
+ def validate(self):
+ """Validate the DAG has a coherent setup.
+
+ This is called by the DAG bag before bagging the DAG.
+ """
+ if not self._check_schedule_interval_matches_timetable():
+ raise AirflowDagInconsistent(
+ f"inconsistent schedule: timetable {self.timetable.summary!r} "
+ f"does not match schedule_interval {self.schedule_interval!r}",
+ )
+ self.params.validate()
+ self.timetable.validate()
def __repr__(self):
return f"<DAG: {self.dag_id}>"
@@ -504,7 +688,7 @@ def __hash__(self):
hash_components = [type(self)]
for c in self._comps:
# task_ids returns a list and lists can't be hashed
- if c == 'task_ids':
+ if c == "task_ids":
val = tuple(self.task_dict.keys())
else:
val = getattr(self, c, None)
@@ -546,7 +730,7 @@ def _upgrade_outdated_dag_access_control(access_control=None):
warnings.warn(
"The 'can_dag_read' and 'can_dag_edit' permissions are deprecated. "
"Please use 'can_read' and 'can_edit', respectively.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=3,
)
@@ -555,19 +739,19 @@ def _upgrade_outdated_dag_access_control(access_control=None):
def date_range(
self,
start_date: pendulum.DateTime,
- num: Optional[int] = None,
- end_date: Optional[datetime] = None,
- ) -> List[datetime]:
+ num: int | None = None,
+ end_date: datetime | None = None,
+ ) -> list[datetime]:
message = "`DAG.date_range()` is deprecated."
if num is not None:
- warnings.warn(message, category=DeprecationWarning, stacklevel=2)
+ warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
return utils_date_range(
start_date=start_date, num=num, delta=self.normalized_schedule_interval
)
message += " Please use `DAG.iter_dagrun_infos_between(..., align=False)` instead."
- warnings.warn(message, category=DeprecationWarning, stacklevel=2)
+ warnings.warn(message, category=RemovedInAirflow3Warning, stacklevel=2)
if end_date is None:
coerced_end_date = timezone.utcnow()
else:
@@ -578,7 +762,7 @@ def date_range(
def is_fixed_time_schedule(self):
warnings.warn(
"`DAG.is_fixed_time_schedule()` is deprecated.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
try:
@@ -595,7 +779,7 @@ def following_schedule(self, dttm):
"""
warnings.warn(
"`DAG.following_schedule()` is deprecated. Use `DAG.next_dagrun_info(restricted=False)` instead.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
data_interval = self.infer_automated_data_interval(timezone.coerce_datetime(dttm))
@@ -609,21 +793,21 @@ def previous_schedule(self, dttm):
warnings.warn(
"`DAG.previous_schedule()` is deprecated.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
if not isinstance(self.timetable, _DataIntervalTimetable):
return None
return self.timetable._get_prev(timezone.coerce_datetime(dttm))
- def get_next_data_interval(self, dag_model: "DagModel") -> Optional[DataInterval]:
+ def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None:
"""Get the data interval of the next scheduled run.
For compatibility, this method infers the data interval from the DAG's
schedule if the run does not have an explicit one set, which is possible
for runs created prior to AIP-39.
- This function is private to Airflow core and should not be depended as a
+ This function is private to Airflow core and should not be depended on as a
part of the Python API.
:meta private:
@@ -648,7 +832,7 @@ def get_run_data_interval(self, run: DagRun) -> DataInterval:
schedule if the run does not have an explicit one set, which is possible for
runs created prior to AIP-39.
- This function is private to Airflow core and should not be depended as a
+ This function is private to Airflow core and should not be depended on as a
part of the Python API.
:meta private:
@@ -686,10 +870,10 @@ def infer_automated_data_interval(self, logical_date: datetime) -> DataInterval:
def next_dagrun_info(
self,
- last_automated_dagrun: Union[None, datetime, DataInterval],
+ last_automated_dagrun: None | datetime | DataInterval,
*,
restricted: bool = True,
- ) -> Optional[DagRunInfo]:
+ ) -> DagRunInfo | None:
"""Get information about the next DagRun of this dag after ``date_last_automated_dagrun``.
This calculates what time interval the next DagRun should operate on
@@ -716,7 +900,7 @@ def next_dagrun_info(
if isinstance(last_automated_dagrun, datetime):
warnings.warn(
"Passing a datetime to DAG.next_dagrun_info is deprecated. Use a DataInterval instead.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
data_interval = self.infer_automated_data_interval(
@@ -742,10 +926,10 @@ def next_dagrun_info(
info = None
return info
- def next_dagrun_after_date(self, date_last_automated_dagrun: Optional[pendulum.DateTime]):
+ def next_dagrun_after_date(self, date_last_automated_dagrun: pendulum.DateTime | None):
warnings.warn(
"`DAG.next_dagrun_after_date()` is deprecated. Please use `DAG.next_dagrun_info()` instead.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
if date_last_automated_dagrun is None:
@@ -776,7 +960,7 @@ def _time_restriction(self) -> TimeRestriction:
def iter_dagrun_infos_between(
self,
- earliest: Optional[pendulum.DateTime],
+ earliest: pendulum.DateTime | None,
latest: pendulum.DateTime,
*,
align: bool = True,
@@ -856,7 +1040,7 @@ def iter_dagrun_infos_between(
)
break
- def get_run_dates(self, start_date, end_date=None):
+ def get_run_dates(self, start_date, end_date=None) -> list:
"""
Returns a list of dates between the interval received as parameter using this
dag's schedule interval. Returned dates can be used for execution dates.
@@ -864,11 +1048,10 @@ def get_run_dates(self, start_date, end_date=None):
:param start_date: The start date of the interval.
:param end_date: The end date of the interval. Defaults to ``timezone.utcnow()``.
:return: A list of dates within the interval following the dag's schedule.
- :rtype: list
"""
warnings.warn(
"`DAG.get_run_dates()` is deprecated. Please use `DAG.iter_dagrun_infos_between()` instead.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
earliest = timezone.coerce_datetime(start_date)
@@ -881,16 +1064,16 @@ def get_run_dates(self, start_date, end_date=None):
def normalize_schedule(self, dttm):
warnings.warn(
"`DAG.normalize_schedule()` is deprecated.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
following = self.following_schedule(dttm)
if not following: # in case of @once
return dttm
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
previous_of_following = self.previous_schedule(following)
if previous_of_following != dttm:
return following
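
Note on the warning-category change in the hunks above: swapping `DeprecationWarning` for a project-specific category only keeps existing filters working if the new category is a subclass of `DeprecationWarning`. The sketch below illustrates that pattern with the standard `warnings` module. `RemovedInExampleWarning`, `old_api`, `new_api` and `collect_categories` are hypothetical names used for illustration and are not part of this diff.

```python
import warnings


class RemovedInExampleWarning(DeprecationWarning):
    """Hypothetical stand-in for a project-specific deprecation category."""


def old_api() -> int:
    # Emit the project-specific category instead of a bare DeprecationWarning.
    warnings.warn("old_api() is deprecated.", RemovedInExampleWarning, stacklevel=2)
    return 42


def new_api() -> int:
    # Silence only the project-specific category while delegating to the old API.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RemovedInExampleWarning)
        return old_api()


def collect_categories(func) -> list:
    # Run func and return the categories of every warning it lets through.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        func()
    return [w.category for w in caught]
```

Because the stand-in subclasses `DeprecationWarning`, code that already filters on `DeprecationWarning` still matches it, while internal callers can suppress just the narrower category.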
@@ -928,7 +1111,7 @@ def full_filepath(self) -> str:
""":meta private:"""
warnings.warn(
"DAG.full_filepath is deprecated in favour of fileloc",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.fileloc
@@ -937,7 +1120,7 @@ def full_filepath(self) -> str:
def full_filepath(self, value) -> None:
warnings.warn(
"DAG.full_filepath is deprecated in favour of fileloc",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
self.fileloc = value
@@ -947,7 +1130,7 @@ def concurrency(self) -> int:
# TODO: Remove in Airflow 3.0
warnings.warn(
"The 'DAG.concurrency' attribute is deprecated. Please use 'DAG.max_active_tasks'.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self._max_active_tasks
@@ -973,7 +1156,7 @@ def access_control(self, value):
self._access_control = DAG._upgrade_outdated_dag_access_control(value)
@property
- def description(self) -> Optional[str]:
+ def description(self) -> str | None:
return self._description
@property
@@ -981,7 +1164,7 @@ def default_view(self) -> str:
return self._default_view
@property
- def pickle_id(self) -> Optional[int]:
+ def pickle_id(self) -> int | None:
return self._pickle_id
@pickle_id.setter
@@ -999,26 +1182,28 @@ def param(self, name: str, default: Any = NOTSET) -> DagParam:
return DagParam(current_dag=self, name=name, default=default)
@property
- def tasks(self) -> List[Operator]:
+ def tasks(self) -> list[Operator]:
return list(self.task_dict.values())
@tasks.setter
def tasks(self, val):
- raise AttributeError('DAG.tasks can not be modified. Use dag.add_task() instead.')
+ raise AttributeError("DAG.tasks can not be modified. Use dag.add_task() instead.")
@property
- def task_ids(self) -> List[str]:
+ def task_ids(self) -> list[str]:
return list(self.task_dict.keys())
@property
- def task_group(self) -> "TaskGroup":
+ def task_group(self) -> TaskGroup:
return self._task_group
@property
def filepath(self) -> str:
""":meta private:"""
warnings.warn(
- "filepath is deprecated, use relative_fileloc instead", DeprecationWarning, stacklevel=2
+ "filepath is deprecated, use relative_fileloc instead",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
)
return str(self.relative_fileloc)
@@ -1027,7 +1212,11 @@ def relative_fileloc(self) -> pathlib.Path:
"""File location of the importable dag 'file' relative to the configured DAGs folder."""
path = pathlib.Path(self.fileloc)
try:
- return path.relative_to(settings.DAGS_FOLDER)
+ rel_path = path.relative_to(self._processor_dags_folder or settings.DAGS_FOLDER)
+ if rel_path == pathlib.Path("."):
+ return path
+ else:
+ return rel_path
except ValueError:
# Not relative to DAGS_FOLDER.
return path
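
Note on the `relative_fileloc` hunk above: `relative_to()` returns `.` when the file location and the base folder are the same path, which is the case the new branch guards against. A minimal standalone sketch of the same logic follows. It takes the folder as an explicit argument and uses `PurePosixPath` so the behaviour does not depend on the host OS; both choices are assumptions made for illustration, not what the diff does.

```python
from pathlib import PurePosixPath


def relative_fileloc(fileloc: str, dags_folder: str) -> PurePosixPath:
    """Return fileloc relative to dags_folder, falling back to the full path."""
    path = PurePosixPath(fileloc)
    try:
        rel_path = path.relative_to(dags_folder)
    except ValueError:
        # fileloc is not located under dags_folder.
        return path
    # relative_to() yields "." when both paths are identical, for example
    # when the configured folder actually points at a single file.
    if rel_path == PurePosixPath("."):
        return path
    return rel_path
```

For example, `relative_fileloc("/dags/team/a.py", "/dags")` gives `team/a.py`, while passing the same path for both arguments returns the full path instead of `.`.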
@@ -1043,7 +1232,6 @@ def owner(self) -> str:
Return list of all owners found in DAG tasks.
:return: Comma separated list of owners in DAG tasks
- :rtype: str
"""
return ", ".join({t.owner for t in self.tasks})
@@ -1069,18 +1257,18 @@ def concurrency_reached(self):
"""This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method."""
warnings.warn(
"This attribute is deprecated. Please use `airflow.models.DAG.get_concurrency_reached` method.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.get_concurrency_reached()
@provide_session
- def get_is_active(self, session=NEW_SESSION) -> Optional[None]:
+ def get_is_active(self, session=NEW_SESSION) -> None:
"""Returns a boolean indicating whether this DAG is active"""
return session.query(DagModel.is_active).filter(DagModel.dag_id == self.dag_id).scalar()
@provide_session
- def get_is_paused(self, session=NEW_SESSION) -> Optional[None]:
+ def get_is_paused(self, session=NEW_SESSION) -> None:
"""Returns a boolean indicating whether this DAG is paused"""
return session.query(DagModel.is_paused).filter(DagModel.dag_id == self.dag_id).scalar()
@@ -1089,7 +1277,7 @@ def is_paused(self):
"""This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method."""
warnings.warn(
"This attribute is deprecated. Please use `airflow.models.DAG.get_is_paused` method.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.get_is_paused()
@@ -1098,12 +1286,12 @@ def is_paused(self):
def normalized_schedule_interval(self) -> ScheduleInterval:
warnings.warn(
"DAG.normalized_schedule_interval() is deprecated.",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
if isinstance(self.schedule_interval, str) and self.schedule_interval in cron_presets:
_schedule_interval: ScheduleInterval = cron_presets.get(self.schedule_interval)
- elif self.schedule_interval == '@once':
+ elif self.schedule_interval == "@once":
_schedule_interval = None
else:
_schedule_interval = self.schedule_interval
@@ -1127,12 +1315,12 @@ def handle_callback(self, dagrun, success=True, reason=None, session=NEW_SESSION
"""
callback = self.on_success_callback if success else self.on_failure_callback
if callback:
- self.log.info('Executing dag callback function: %s', callback)
+ self.log.info("Executing dag callback function: %s", callback)
tis = dagrun.get_task_instances(session=session)
ti = tis[-1] # get the last TaskInstance of the DagRun
ti.task = self.get_task(ti.task_id)
context = ti.get_template_context(session=session)
- context.update({'reason': reason})
+ context.update({"reason": reason})
try:
callback(context)
except Exception:
@@ -1179,8 +1367,8 @@ def get_num_active_runs(self, external_trigger=None, only_running=True, session=
@provide_session
def get_dagrun(
self,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
session: Session = NEW_SESSION,
):
"""
@@ -1224,7 +1412,7 @@ def get_dagruns_between(self, start_date, end_date, session=NEW_SESSION):
return dagruns
@provide_session
- def get_latest_execution_date(self, session: Session = NEW_SESSION) -> Optional[pendulum.DateTime]:
+ def get_latest_execution_date(self, session: Session = NEW_SESSION) -> pendulum.DateTime | None:
"""Returns the latest date for which at least one dag run exists"""
return session.query(func.max(DagRun.execution_date)).filter(DagRun.dag_id == self.dag_id).scalar()
@@ -1233,7 +1421,7 @@ def latest_execution_date(self):
"""This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`."""
warnings.warn(
"This attribute is deprecated. Please use `airflow.models.DAG.get_latest_execution_date`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.get_latest_execution_date()
@@ -1250,8 +1438,8 @@ def subdags(self):
isinstance(task, SubDagOperator)
or
# TODO remove in Airflow 2.0
- type(task).__name__ == 'SubDagOperator'
- or task.task_type == 'SubDagOperator'
+ type(task).__name__ == "SubDagOperator"
+ or task.task_type == "SubDagOperator"
):
subdag_lst.append(task.subdag)
subdag_lst += task.subdag.subdags
@@ -1270,10 +1458,10 @@ def get_template_env(self, *, force_sandboxed: bool = False) -> jinja2.Environme
# Default values (for backward compatibility)
jinja_env_options = {
- 'loader': jinja2.FileSystemLoader(searchpath),
- 'undefined': self.template_undefined,
- 'extensions': ["jinja2.ext.do"],
- 'cache_size': 0,
+ "loader": jinja2.FileSystemLoader(searchpath),
+ "undefined": self.template_undefined,
+ "extensions": ["jinja2.ext.do"],
+ "cache_size": 0,
}
if self.jinja_environment_kwargs:
jinja_env_options.update(self.jinja_environment_kwargs)
@@ -1306,7 +1494,7 @@ def get_task_instances_before(
num: int,
*,
session: Session = NEW_SESSION,
- ) -> List[TaskInstance]:
+ ) -> list[TaskInstance]:
"""Get ``num`` task instances before (including) ``base_date``.
The returned list may contain exactly ``num`` task instances. It can
@@ -1314,7 +1502,7 @@ def get_task_instances_before(
``base_date``, or more if there are manual task runs between the
requested period, which does not count toward ``num``.
"""
- min_date: Optional[datetime] = (
+ min_date: datetime | None = (
session.query(DagRun.execution_date)
.filter(
DagRun.dag_id == self.dag_id,
@@ -1333,11 +1521,11 @@ def get_task_instances_before(
@provide_session
def get_task_instances(
self,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- state: Optional[List[TaskInstanceState]] = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ state: list[TaskInstanceState] | None = None,
session: Session = NEW_SESSION,
- ) -> List[TaskInstance]:
+ ) -> list[TaskInstance]:
if not start_date:
start_date = (timezone.utcnow() - timedelta(30)).replace(
hour=0, minute=0, second=0, microsecond=0
@@ -1360,17 +1548,17 @@ def get_task_instances(
def _get_task_instances(
self,
*,
- task_ids: Optional[Collection[Union[str, Tuple[str, int]]]],
- start_date: Optional[datetime],
- end_date: Optional[datetime],
- run_id: Optional[str],
- state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+ task_ids: Collection[str | tuple[str, int]] | None,
+ start_date: datetime | None,
+ end_date: datetime | None,
+ run_id: str | None,
+ state: TaskInstanceState | Sequence[TaskInstanceState],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- exclude_task_ids: Optional[Collection[Union[str, Tuple[str, int]]]],
+ exclude_task_ids: Collection[str | tuple[str, int]] | None,
session: Session,
- dag_bag: Optional["DagBag"] = ...,
+ dag_bag: DagBag | None = ...,
) -> Iterable[TaskInstance]:
... # pragma: no cover
@@ -1378,43 +1566,43 @@ def _get_task_instances(
def _get_task_instances(
self,
*,
- task_ids: Optional[Collection[Union[str, Tuple[str, int]]]],
+ task_ids: Collection[str | tuple[str, int]] | None,
as_pk_tuple: Literal[True],
- start_date: Optional[datetime],
- end_date: Optional[datetime],
- run_id: Optional[str],
- state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+ start_date: datetime | None,
+ end_date: datetime | None,
+ run_id: str | None,
+ state: TaskInstanceState | Sequence[TaskInstanceState],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- exclude_task_ids: Optional[Collection[Union[str, Tuple[str, int]]]],
+ exclude_task_ids: Collection[str | tuple[str, int]] | None,
session: Session,
- dag_bag: Optional["DagBag"] = ...,
+ dag_bag: DagBag | None = ...,
recursion_depth: int = ...,
max_recursion_depth: int = ...,
- visited_external_tis: Set[TaskInstanceKey] = ...,
- ) -> Set["TaskInstanceKey"]:
+ visited_external_tis: set[TaskInstanceKey] = ...,
+ ) -> set[TaskInstanceKey]:
... # pragma: no cover
def _get_task_instances(
self,
*,
- task_ids: Optional[Collection[Union[str, Tuple[str, int]]]],
+ task_ids: Collection[str | tuple[str, int]] | None,
as_pk_tuple: Literal[True, None] = None,
- start_date: Optional[datetime],
- end_date: Optional[datetime],
- run_id: Optional[str],
- state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
+ start_date: datetime | None,
+ end_date: datetime | None,
+ run_id: str | None,
+ state: TaskInstanceState | Sequence[TaskInstanceState],
include_subdags: bool,
include_parentdag: bool,
include_dependent_dags: bool,
- exclude_task_ids: Optional[Collection[Union[str, Tuple[str, int]]]],
+ exclude_task_ids: Collection[str | tuple[str, int]] | None,
session: Session,
- dag_bag: Optional["DagBag"] = None,
+ dag_bag: DagBag | None = None,
recursion_depth: int = 0,
- max_recursion_depth: Optional[int] = None,
- visited_external_tis: Optional[Set[TaskInstanceKey]] = None,
- ) -> Union[Iterable[TaskInstance], Set[TaskInstanceKey]]:
+ max_recursion_depth: int | None = None,
+ visited_external_tis: set[TaskInstanceKey] | None = None,
+ ) -> Iterable[TaskInstance] | set[TaskInstanceKey]:
TI = TaskInstance
# If we are looking at subdags/dependent dags we want to avoid UNION calls
@@ -1423,7 +1611,7 @@ def _get_task_instances(
#
# This will be empty if we are only looking at one dag, in which case
# we can return the filtered TI query object directly.
- result: Set[TaskInstanceKey] = set()
+ result: set[TaskInstanceKey] = set()
# Do we want full objects, or just the primary columns?
if as_pk_tuple:
@@ -1481,7 +1669,7 @@ def _get_task_instances(
visited_external_tis = set()
p_dag = self.parent_dag.partial_subset(
- task_ids_or_regex=r"^{}$".format(self.dag_id.split('.')[1]),
+ task_ids_or_regex=r"^{}$".format(self.dag_id.split(".")[1]),
include_upstream=False,
include_downstream=True,
)
@@ -1553,6 +1741,8 @@ def _get_task_instances(
for tii in external_tis:
if not dag_bag:
+ from airflow.models.dagbag import DagBag
+
dag_bag = DagBag(read_dags_from_db=True)
external_dag = dag_bag.get_dag(tii.dag_id, session=session)
if not external_dag:
@@ -1585,7 +1775,7 @@ def _get_task_instances(
if result or as_pk_tuple:
# Only execute the `ti` query if we have also collected some other results (i.e. subdags etc.)
if as_pk_tuple:
- result.update(TaskInstanceKey(*cols) for cols in tis.all())
+ result.update(TaskInstanceKey(**cols._mapping) for cols in tis.all())
else:
result.update(ti.key for ti in tis)
@@ -1618,9 +1808,9 @@ def set_task_instance_state(
self,
*,
task_id: str,
- map_indexes: Optional[Collection[int]] = None,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
+ map_indexes: Collection[int] | None = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
state: TaskInstanceState,
upstream: bool = False,
downstream: bool = False,
@@ -1628,7 +1818,7 @@ def set_task_instance_state(
past: bool = False,
commit: bool = True,
session=NEW_SESSION,
- ) -> List[TaskInstance]:
+ ) -> list[TaskInstance]:
"""
Set the state of a TaskInstance to the given state, and clear its downstream tasks that are
in failed or upstream_failed state.
@@ -1653,7 +1843,7 @@ def set_task_instance_state(
task = self.get_task(task_id)
task.dag = self
- tasks_to_set_state: List[Union[Operator, Tuple[Operator, int]]]
+ tasks_to_set_state: list[Operator | tuple[Operator, int]]
if map_indexes is None:
tasks_to_set_state = [task]
else:
@@ -1709,12 +1899,12 @@ def set_task_instance_state(
return altered
@property
- def roots(self) -> List[Operator]:
+ def roots(self) -> list[Operator]:
"""Return nodes with no parents. These are first to execute and are called roots or root nodes."""
return [task for task in self.tasks if not task.upstream_list]
@property
- def leaves(self) -> List[Operator]:
+ def leaves(self) -> list[Operator]:
"""Return nodes with no children. These are last to execute and are called leaves or leaf nodes."""
return [task for task in self.tasks if not task.downstream_list]
@@ -1741,13 +1931,13 @@ def set_dag_runs_state(
self,
state: str = State.RUNNING,
session: Session = NEW_SESSION,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
- dag_ids: List[str] = [],
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ dag_ids: list[str] = [],
) -> None:
warnings.warn(
"This method is deprecated and will be removed in a future version.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=3,
)
dag_ids = dag_ids or [self.dag_id]
@@ -1756,14 +1946,14 @@ def set_dag_runs_state(
query = query.filter(DagRun.execution_date >= start_date)
if end_date:
query = query.filter(DagRun.execution_date <= end_date)
- query.update({DagRun.state: state}, synchronize_session='fetch')
+ query.update({DagRun.state: state}, synchronize_session="fetch")
@provide_session
def clear(
self,
- task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
- start_date: Optional[datetime] = None,
- end_date: Optional[datetime] = None,
+ task_ids: Collection[str | tuple[str, int]] | None = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
only_failed: bool = False,
only_running: bool = False,
confirm_prompt: bool = False,
@@ -1774,10 +1964,10 @@ def clear(
session: Session = NEW_SESSION,
get_tis: bool = False,
recursion_depth: int = 0,
- max_recursion_depth: Optional[int] = None,
- dag_bag: Optional["DagBag"] = None,
- exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(),
- ) -> Union[int, Iterable[TaskInstance]]:
+ max_recursion_depth: int | None = None,
+ dag_bag: DagBag | None = None,
+ exclude_task_ids: frozenset[str] | frozenset[tuple[str, int]] | None = frozenset(),
+ ) -> int | Iterable[TaskInstance]:
"""
Clears a set of task instances associated with the current dag for
a specified date range.
@@ -1802,7 +1992,7 @@ def clear(
if get_tis:
warnings.warn(
"Passing `get_tis` to dag.clear() is deprecated. Use `dry_run` parameter instead.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
dry_run = True
@@ -1810,13 +2000,13 @@ def clear(
if recursion_depth:
warnings.warn(
"Passing `recursion_depth` to dag.clear() is deprecated.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
if max_recursion_depth:
warnings.warn(
"Passing `max_recursion_depth` to dag.clear() is deprecated.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
@@ -1937,12 +2127,12 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
- if k not in ('user_defined_macros', 'user_defined_filters', '_log'):
+ if k not in ("user_defined_macros", "user_defined_filters", "_log"):
setattr(result, k, copy.deepcopy(v, memo))
result.user_defined_macros = self.user_defined_macros
result.user_defined_filters = self.user_defined_filters
- if hasattr(self, '_log'):
+ if hasattr(self, "_log"):
result._log = self._log
return result
@@ -1950,14 +2140,14 @@ def sub_dag(self, *args, **kwargs):
"""This method is deprecated in favor of partial_subset"""
warnings.warn(
"This method is deprecated and will be removed in a future version. Please use partial_subset",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.partial_subset(*args, **kwargs)
def partial_subset(
self,
- task_ids_or_regex: Union[str, re.Pattern, Iterable[str]],
+ task_ids_or_regex: str | re.Pattern | Iterable[str],
include_downstream=False,
include_upstream=True,
include_direct_upstream=False,
@@ -1976,7 +2166,6 @@ def partial_subset(
:param include_direct_upstream: Include all tasks directly upstream of matched
and downstream (if include_downstream = True) tasks
"""
-
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
@@ -1990,14 +2179,14 @@ def partial_subset(
else:
matched_tasks = [t for t in self.tasks if t.task_id in task_ids_or_regex]
- also_include: List[Operator] = []
+ also_include: list[Operator] = []
for t in matched_tasks:
if include_downstream:
also_include.extend(t.get_flat_relatives(upstream=False))
if include_upstream:
also_include.extend(t.get_flat_relatives(upstream=True))
- direct_upstreams: List[Operator] = []
+ direct_upstreams: list[Operator] = []
if include_direct_upstream:
for t in itertools.chain(matched_tasks, also_include):
upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
@@ -2006,7 +2195,7 @@ def partial_subset(
# Compiling the unique list of tasks that made the cut
# Make sure to not recursively deepcopy the dag or task_group while copying the task.
# task_group is reset later
- def _deepcopy_task(t) -> "Operator":
+ def _deepcopy_task(t) -> Operator:
memo.setdefault(id(t.task_group), None)
return copy.deepcopy(t, memo)
@@ -2064,6 +2253,13 @@ def filter_task_group(group, parent_group):
def has_task(self, task_id: str):
return task_id in self.task_dict
+ def has_task_group(self, task_group_id: str) -> bool:
+ return task_group_id in self.task_group_dict
+
+ @cached_property
+ def task_group_dict(self):
+ return {k: v for k, v in self._task_group.get_task_group_dict().items() if k is not None}
+
def get_task(self, task_id: str, include_subdags: bool = False) -> Operator:
if task_id in self.task_dict:
return self.task_dict[task_id]
@@ -2075,16 +2271,16 @@ def get_task(self, task_id: str, include_subdags: bool = False) -> Operator:
def pickle_info(self):
d = {}
- d['is_picklable'] = True
+ d["is_picklable"] = True
try:
dttm = timezone.utcnow()
pickled = pickle.dumps(self)
- d['pickle_len'] = len(pickled)
- d['pickling_duration'] = str(timezone.utcnow() - dttm)
+ d["pickle_len"] = len(pickled)
+ d["pickling_duration"] = str(timezone.utcnow() - dttm)
except Exception as e:
self.log.debug(e)
- d['is_picklable'] = False
- d['stacktrace'] = traceback.format_exc()
+ d["is_picklable"] = False
+ d["stacktrace"] = traceback.format_exc()
return d
@provide_session
@@ -2115,7 +2311,7 @@ def get_downstream(task, level=0):
get_downstream(t)
@property
- def task(self) -> "TaskDecoratorCollection":
+ def task(self) -> TaskDecoratorCollection:
from airflow.decorators import task
return cast("TaskDecoratorCollection", functools.partial(task, dag=self))
@@ -2126,6 +2322,8 @@ def add_task(self, task: Operator) -> None:
:param task: the task you want to add
"""
+ from airflow.utils.task_group import TaskGroupContext
+
if not self.start_date and not task.start_date:
raise AirflowException("DAG is missing the start_date parameter")
# if the task has no start date, assign it the same as the DAG
@@ -2144,15 +2342,22 @@ def add_task(self, task: Operator) -> None:
elif task.end_date and self.end_date:
task.end_date = min(task.end_date, self.end_date)
+ task_id = task.task_id
+ if not task.task_group:
+ task_group = TaskGroupContext.get_current_task_group(self)
+ if task_group:
+ task_id = task_group.child_id(task_id)
+ task_group.add(task)
+
if (
- task.task_id in self.task_dict and self.task_dict[task.task_id] is not task
- ) or task.task_id in self._task_group.used_group_ids:
- raise DuplicateTaskIdFound(f"Task id '{task.task_id}' has already been added to the DAG")
+ task_id in self.task_dict and self.task_dict[task_id] is not task
+ ) or task_id in self._task_group.used_group_ids:
+ raise DuplicateTaskIdFound(f"Task id '{task_id}' has already been added to the DAG")
else:
- self.task_dict[task.task_id] = task
+ self.task_dict[task_id] = task
task.dag = self
# Add task_id to used_group_ids to prevent group_id and task_id collisions.
- self._task_group.used_group_ids.add(task.task_id)
+ self._task_group.used_group_ids.add(task_id)
self.task_count = len(self.task_dict)
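
Note on the `add_task` hunk above: the duplicate check now runs against the group-qualified ID rather than the bare `task.task_id`. The sketch below reproduces that check with plain dictionaries and sets. `register_task` and `DuplicateTaskIdError` are hypothetical names, and the `"<group_id>.<task_id>"` prefix format is an assumption about what `child_id()` produces.

```python
from typing import Dict, Optional, Set


class DuplicateTaskIdError(ValueError):
    """Hypothetical stand-in for the DuplicateTaskIdFound exception in the diff."""


def register_task(
    task_dict: Dict[str, object],
    used_ids: Set[str],
    task_id: str,
    task: object,
    group_id: Optional[str] = None,
) -> str:
    """Register task under its (possibly group-qualified) ID and return that ID."""
    # Assumption: a task group prefixes child IDs as "<group_id>.<task_id>".
    if group_id:
        task_id = f"{group_id}.{task_id}"
    clashes_with_other_task = task_id in task_dict and task_dict[task_id] is not task
    if clashes_with_other_task or task_id in used_ids:
        raise DuplicateTaskIdError(f"Task id '{task_id}' has already been added")
    task_dict[task_id] = task
    # Reserve the ID so a later group or task cannot reuse it.
    used_ids.add(task_id)
    return task_id
```

With this shape, a task named `extract` inside group `etl` and a top-level task named `extract` can coexist, because only the qualified IDs are compared.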
@@ -2169,7 +2374,7 @@ def _remove_task(self, task_id: str) -> None:
# This is "private" as removing could leave a hole in dependencies if done incorrectly, and this
# doesn't guard against that
task = self.task_dict.pop(task_id)
- tg = getattr(task, 'task_group', None)
+ tg = getattr(task, "task_group", None)
if tg:
tg._remove(task)
@@ -2182,7 +2387,7 @@ def run(
mark_success=False,
local=False,
executor=None,
- donot_pickle=conf.getboolean('core', 'donot_pickle'),
+ donot_pickle=conf.getboolean("core", "donot_pickle"),
ignore_task_deps=False,
ignore_first_depends_on_past=True,
pool=None,
@@ -2193,6 +2398,7 @@ def run(
run_backwards=False,
run_at_least_once=False,
continue_on_failures=False,
+ disable_retry=False,
):
"""
Runs the DAG.
@@ -2243,6 +2449,7 @@ def run(
run_backwards=run_backwards,
run_at_least_once=run_at_least_once,
continue_on_failures=continue_on_failures,
+ disable_retry=disable_retry,
)
job.run()
@@ -2256,20 +2463,96 @@ def cli(self):
args = parser.parse_args()
args.func(args, self)
+ @provide_session
+ def test(
+ self,
+ execution_date: datetime | None = None,
+ run_conf: dict[str, Any] | None = None,
+ conn_file_path: str | None = None,
+ variable_file_path: str | None = None,
+ session: Session = NEW_SESSION,
+ ) -> None:
+ """
+ Execute one single DagRun for a given DAG and execution date.
+
+ :param execution_date: execution date for the DAG run
+ :param run_conf: configuration to pass to newly created dagrun
+ :param conn_file_path: file path to a connection file in either yaml or json
+ :param variable_file_path: file path to a variable file in either yaml or json
+ :param session: database connection (optional)
+ """
+
+ def add_logger_if_needed(ti: TaskInstance):
+ """
+ Add a formatted logger to the taskinstance so all logs are surfaced to the command line instead
+ of into a task file. Since this is a local test run, it is much better for the user to see logs
+ in the command line, rather than needing to search for a log file.
+ Args:
+ ti: The taskinstance that will receive a logger
+
+ """
+ format = logging.Formatter("[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
+ handler = logging.StreamHandler(sys.stdout)
+ handler.level = logging.INFO
+ handler.setFormatter(format)
+ # only add log handler once
+ if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
+ self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
+ ti.log.addHandler(handler)
+
+ if conn_file_path or variable_file_path:
+ local_secrets = LocalFilesystemBackend(
+ variables_file_path=variable_file_path, connections_file_path=conn_file_path
+ )
+ secrets_backend_list.insert(0, local_secrets)
+
+ execution_date = execution_date or timezone.utcnow()
+ self.log.debug("Clearing existing task instances for execution date %s", execution_date)
+ self.clear(
+ start_date=execution_date,
+ end_date=execution_date,
+ dag_run_state=False, # type: ignore
+ session=session,
+ )
+ self.log.debug("Getting dagrun for dag %s", self.dag_id)
+ dr: DagRun = _get_or_create_dagrun(
+ dag=self,
+ start_date=execution_date,
+ execution_date=execution_date,
+ run_id=DagRun.generate_run_id(DagRunType.MANUAL, execution_date),
+ session=session,
+ conf=run_conf,
+ )
+
+ tasks = self.task_dict
+ self.log.debug("starting dagrun")
+ # Instead of starting a scheduler, we run the minimal loop possible to check
+ # for task readiness and dependency management. This is notably faster
+ # than creating a BackfillJob and allows us to surface logs to the user
+ while dr.state == State.RUNNING:
+ schedulable_tis, _ = dr.update_state(session=session)
+ for ti in schedulable_tis:
+ add_logger_if_needed(ti)
+ ti.task = tasks[ti.task_id]
+ _run_task(ti, session=session)
+ if conn_file_path or variable_file_path:
+ # Remove the local variables we have added to the secrets_backend_list
+ secrets_backend_list.pop(0)
+
@provide_session
def create_dagrun(
self,
state: DagRunState,
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
- start_date: Optional[datetime] = None,
- external_trigger: Optional[bool] = False,
- conf: Optional[dict] = None,
- run_type: Optional[DagRunType] = None,
- session=NEW_SESSION,
- dag_hash: Optional[str] = None,
- creating_job_id: Optional[int] = None,
- data_interval: Optional[Tuple[datetime, datetime]] = None,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
+ start_date: datetime | None = None,
+ external_trigger: bool | None = False,
+ conf: dict | None = None,
+ run_type: DagRunType | None = None,
+ session: Session = NEW_SESSION,
+ dag_hash: str | None = None,
+ creating_job_id: int | None = None,
+ data_interval: tuple[datetime, datetime] | None = None,
):
"""
Creates a dag run from this dag including the tasks associated with this dag.
@@ -2287,15 +2570,46 @@ def create_dagrun(
:param dag_hash: Hash of Serialized DAG
:param data_interval: Data interval of the DagRun
"""
+ logical_date = timezone.coerce_datetime(execution_date)
+
+ if data_interval and not isinstance(data_interval, DataInterval):
+ data_interval = DataInterval(*map(timezone.coerce_datetime, data_interval))
+
+ if data_interval is None and logical_date is not None:
+ warnings.warn(
+ "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated",
+ RemovedInAirflow3Warning,
+ stacklevel=3,
+ )
+ if run_type == DagRunType.MANUAL:
+ data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date)
+ else:
+ data_interval = self.infer_automated_data_interval(logical_date)
+
+ if run_type is None or isinstance(run_type, DagRunType):
+ pass
+ elif isinstance(run_type, str): # Compatibility: run_type used to be a str.
+ run_type = DagRunType(run_type)
+ else:
+ raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}")
+
if run_id: # Infer run_type from run_id if needed.
if not isinstance(run_id, str):
raise ValueError(f"`run_id` should be a str, not {type(run_id)}")
- if not run_type:
- run_type = DagRunType.from_run_id(run_id)
- elif run_type and execution_date is not None: # Generate run_id from run_type and execution_date.
- if not isinstance(run_type, DagRunType):
- raise ValueError(f"`run_type` should be a DagRunType, not {type(run_type)}")
- run_id = DagRun.generate_run_id(run_type, execution_date)
+ inferred_run_type = DagRunType.from_run_id(run_id)
+ if run_type is None:
+ # No explicit type given, use the inferred type.
+ run_type = inferred_run_type
+ elif run_type == DagRunType.MANUAL and inferred_run_type != DagRunType.MANUAL:
+ # Prevent a manual run from using an ID that looks like a scheduled run.
+ raise ValueError(
+ f"A {run_type.value} DAG run cannot use ID {run_id!r} since it "
+ f"is reserved for {inferred_run_type.value} runs"
+ )
+ elif run_type and logical_date is not None: # Generate run_id from run_type and execution_date.
+ run_id = self.timetable.generate_run_id(
+ run_type=run_type, logical_date=logical_date, data_interval=data_interval
+ )
else:
raise AirflowException(
"Creating DagRun needs either `run_id` or both `run_type` and `execution_date`"
@@ -2305,21 +2619,9 @@ def create_dagrun(
warnings.warn(
"Using forward slash ('/') in a DAG run ID is deprecated. Note that this character "
"also makes the run impossible to retrieve via Airflow's REST API.",
- DeprecationWarning,
- stacklevel=3,
- )
-
- logical_date = timezone.coerce_datetime(execution_date)
- if data_interval is None and logical_date is not None:
- warnings.warn(
- "Calling `DAG.create_dagrun()` without an explicit data interval is deprecated",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=3,
)
- if run_type == DagRunType.MANUAL:
- data_interval = self.timetable.infer_manual_data_interval(run_after=logical_date)
- else:
- data_interval = self.infer_automated_data_interval(logical_date)
# create a copy of params before validating
copied_params = copy.deepcopy(self.params)
@@ -2352,18 +2654,27 @@ def create_dagrun(
@classmethod
@provide_session
- def bulk_sync_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
+ def bulk_sync_to_db(
+ cls,
+ dags: Collection[DAG],
+ session=NEW_SESSION,
+ ):
"""This method is deprecated in favor of bulk_write_to_db"""
warnings.warn(
"This method is deprecated and will be removed in a future version. Please use bulk_write_to_db",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
- return cls.bulk_write_to_db(dags, session)
+ return cls.bulk_write_to_db(dags=dags, session=session)
@classmethod
@provide_session
- def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
+ def bulk_write_to_db(
+ cls,
+ dags: Collection[DAG],
+ processor_subdir: str | None = None,
+ session=NEW_SESSION,
+ ):
"""
Ensure the DagModel rows for the given dags are up-to-date in the dag table in the DB, including
calculated fields.
@@ -2378,16 +2689,18 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
log.info("Sync %s DAGs", len(dags))
dag_by_ids = {dag.dag_id: dag for dag in dags}
+
dag_ids = set(dag_by_ids.keys())
query = (
session.query(DagModel)
.options(joinedload(DagModel.tags, innerjoin=False))
.filter(DagModel.dag_id.in_(dag_ids))
+ .options(joinedload(DagModel.schedule_dataset_references))
+ .options(joinedload(DagModel.task_outlet_dataset_references))
)
- orm_dags: List[DagModel] = with_row_locks(query, of=DagModel, session=session).all()
-
- existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags}
- missing_dag_ids = dag_ids.difference(existing_dag_ids)
+ orm_dags: list[DagModel] = with_row_locks(query, of=DagModel, session=session).all()
+ existing_dags = {orm_dag.dag_id: orm_dag for orm_dag in orm_dags}
+ missing_dag_ids = dag_ids.difference(existing_dags)
for missing_dag_id in missing_dag_ids:
orm_dag = DagModel(dag_id=missing_dag_id)
@@ -2403,7 +2716,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
most_recent_subq = (
session.query(DagRun.dag_id, func.max(DagRun.execution_date).label("max_execution_date"))
.filter(
- DagRun.dag_id.in_(existing_dag_ids),
+ DagRun.dag_id.in_(existing_dags),
or_(DagRun.run_type == DagRunType.BACKFILL_JOB, DagRun.run_type == DagRunType.SCHEDULED),
)
.group_by(DagRun.dag_id)
@@ -2417,7 +2730,7 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
# Get number of active dagruns for all dags we are processing as a single query.
- num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dag_ids, session=session)
+ num_active_runs = DagRun.active_runs_of_dags(dag_ids=existing_dags, session=session)
filelocs = []
@@ -2438,13 +2751,14 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
orm_dag.last_parsed_time = timezone.utcnow()
orm_dag.default_view = dag.default_view
orm_dag.description = dag.description
- orm_dag.schedule_interval = dag.schedule_interval
orm_dag.max_active_tasks = dag.max_active_tasks
orm_dag.max_active_runs = dag.max_active_runs
orm_dag.has_task_concurrency_limits = any(t.max_active_tis_per_dag is not None for t in dag.tasks)
+ orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
+ orm_dag.processor_subdir = processor_subdir
- run: Optional[DagRun] = most_recent_runs.get(dag.dag_id)
+ run: DagRun | None = most_recent_runs.get(dag.dag_id)
if run is None:
data_interval = None
else:
@@ -2467,17 +2781,119 @@ def bulk_write_to_db(cls, dags: Collection["DAG"], session=NEW_SESSION):
orm_dag.tags.append(dag_tag_orm)
session.add(dag_tag_orm)
+ orm_dag_links = orm_dag.dag_owner_links or []
+ for orm_dag_link in orm_dag_links:
+ if orm_dag_link not in dag.owner_links:
+ session.delete(orm_dag_link)
+ for owner_name, owner_link in dag.owner_links.items():
+ dag_owner_orm = DagOwnerAttributes(dag_id=dag.dag_id, owner=owner_name, link=owner_link)
+ session.add(dag_owner_orm)
+
DagCode.bulk_sync_to_db(filelocs, session=session)
+ from airflow.datasets import Dataset
+ from airflow.models.dataset import (
+ DagScheduleDatasetReference,
+ DatasetModel,
+ TaskOutletDatasetReference,
+ )
+
+ dag_references = collections.defaultdict(set)
+ outlet_references = collections.defaultdict(set)
+ # We can't use a set here as we want to preserve order
+ outlet_datasets: dict[Dataset, None] = {}
+ input_datasets: dict[Dataset, None] = {}
+
+ # here we go through dags and tasks to check for dataset references
+ # if there are now none and previously there were some, we delete them
+ # if there are now *any*, we add them to the above data structures, and
+ # later we'll persist them to the database.
+ for dag in dags:
+ curr_orm_dag = existing_dags.get(dag.dag_id)
+ if not dag.dataset_triggers:
+ if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
+ curr_orm_dag.schedule_dataset_references = []
+ for dataset in dag.dataset_triggers:
+ dag_references[dag.dag_id].add(dataset.uri)
+ input_datasets[DatasetModel.from_public(dataset)] = None
+ curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
+ for task in dag.tasks:
+ dataset_outlets = [x for x in task.outlets or [] if isinstance(x, Dataset)]
+ if not dataset_outlets:
+ if curr_outlet_references:
+ this_task_outlet_refs = [
+ x
+ for x in curr_outlet_references
+ if x.dag_id == dag.dag_id and x.task_id == task.task_id
+ ]
+ for ref in this_task_outlet_refs:
+ curr_outlet_references.remove(ref)
+ for d in dataset_outlets:
+ outlet_references[(task.dag_id, task.task_id)].add(d.uri)
+ outlet_datasets[DatasetModel.from_public(d)] = None
+ all_datasets = outlet_datasets
+ all_datasets.update(input_datasets)
+
+ # store datasets
+ stored_datasets = {}
+ for dataset in all_datasets:
+ stored_dataset = session.query(DatasetModel).filter(DatasetModel.uri == dataset.uri).first()
+ if stored_dataset:
+ # Some datasets may have been previously unreferenced, and therefore orphaned by the
+ # scheduler. But if we're here, then we have found that dataset again in our DAGs, which
+ # means that it is no longer an orphan, so set is_orphaned to False.
+ stored_dataset.is_orphaned = expression.false()
+ stored_datasets[stored_dataset.uri] = stored_dataset
+ else:
+ session.add(dataset)
+ stored_datasets[dataset.uri] = dataset
+
+ session.flush() # this is required to ensure each dataset has its PK loaded
+
+ del all_datasets
+
+ # reconcile dag-schedule-on-dataset references
+ for dag_id, uri_list in dag_references.items():
+ dag_refs_needed = {
+ DagScheduleDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id)
+ for uri in uri_list
+ }
+ dag_refs_stored = set(
+ existing_dags.get(dag_id)
+ and existing_dags.get(dag_id).schedule_dataset_references # type: ignore
+ or []
+ )
+ dag_refs_to_add = {x for x in dag_refs_needed if x not in dag_refs_stored}
+ session.bulk_save_objects(dag_refs_to_add)
+ for obj in dag_refs_stored - dag_refs_needed:
+ session.delete(obj)
+
+ existing_task_outlet_refs_dict = collections.defaultdict(set)
+ for dag_id, orm_dag in existing_dags.items():
+ for todr in orm_dag.task_outlet_dataset_references:
+ existing_task_outlet_refs_dict[(dag_id, todr.task_id)].add(todr)
+
+ # reconcile task-outlet-dataset references
+ for (dag_id, task_id), uri_list in outlet_references.items():
+ task_refs_needed = {
+ TaskOutletDatasetReference(dataset_id=stored_datasets[uri].id, dag_id=dag_id, task_id=task_id)
+ for uri in uri_list
+ }
+ task_refs_stored = existing_task_outlet_refs_dict[(dag_id, task_id)]
+ task_refs_to_add = {x for x in task_refs_needed if x not in task_refs_stored}
+ session.bulk_save_objects(task_refs_to_add)
+ for obj in task_refs_stored - task_refs_needed:
+ session.delete(obj)
+
# Issue SQL/finish "Unit of Work", but let @provide_session commit (or if passed a session, let caller
# decide when to commit).
session.flush()
for dag in dags:
- cls.bulk_write_to_db(dag.subdags, session=session)
+ cls.bulk_write_to_db(dag.subdags, processor_subdir=processor_subdir, session=session)
@provide_session
- def sync_to_db(self, session=NEW_SESSION):
+ def sync_to_db(self, processor_subdir: str | None = None, session=NEW_SESSION):
"""
Save attributes about this DAG to the DB. Note that this method
can be called for both DAGs and SubDAGs. A SubDag is actually a
@@ -2485,12 +2901,12 @@ def sync_to_db(self, session=NEW_SESSION):
:return: None
"""
- self.bulk_write_to_db([self], session)
+ self.bulk_write_to_db([self], processor_subdir=processor_subdir, session=session)
def get_default_view(self):
"""This is only there for backward compatible jinja2 templates"""
if self.default_view is None:
- return conf.get('webserver', 'dag_default_view').lower()
+ return conf.get("webserver", "dag_default_view").lower()
else:
return self.default_view
@@ -2538,7 +2954,7 @@ def deactivate_stale_dags(expiration_date, session=NEW_SESSION):
@staticmethod
@provide_session
- def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION):
+ def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSION) -> int:
"""
Returns the number of task instances in the given DAG.
@@ -2547,7 +2963,6 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSI
:param task_ids: A list of valid task IDs for the given DAG
:param states: A list of states to filter by if supplied
:return: The number of running tasks
- :rtype: int
"""
qry = session.query(func.count(TaskInstance.task_id)).filter(
TaskInstance.dag_id == dag_id,
@@ -2574,28 +2989,32 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=NEW_SESSI
def get_serialized_fields(cls):
"""Stringified DAGs and operators contain exactly these fields."""
if not cls.__serialized_fields:
- cls.__serialized_fields = frozenset(vars(DAG(dag_id='test')).keys()) - {
- 'parent_dag',
- '_old_context_manager_dags',
- 'safe_dag_id',
- 'last_loaded',
- 'user_defined_filters',
- 'user_defined_macros',
- 'partial',
- 'params',
- '_pickle_id',
- '_log',
- 'task_dict',
- 'template_searchpath',
- 'sla_miss_callback',
- 'on_success_callback',
- 'on_failure_callback',
- 'template_undefined',
- 'jinja_environment_kwargs',
+ exclusion_list = {
+ "parent_dag",
+ "schedule_dataset_references",
+ "task_outlet_dataset_references",
+ "_old_context_manager_dags",
+ "safe_dag_id",
+ "last_loaded",
+ "user_defined_filters",
+ "user_defined_macros",
+ "partial",
+ "params",
+ "_pickle_id",
+ "_log",
+ "task_dict",
+ "template_searchpath",
+ "sla_miss_callback",
+ "on_success_callback",
+ "on_failure_callback",
+ "template_undefined",
+ "jinja_environment_kwargs",
# has_on_*_callback are only stored if the value is True, as the default is False
- 'has_on_success_callback',
- 'has_on_failure_callback',
+ "has_on_success_callback",
+ "has_on_failure_callback",
+ "auto_register",
}
+ cls.__serialized_fields = frozenset(vars(DAG(dag_id="test")).keys()) - exclusion_list
return cls.__serialized_fields
def get_edge_info(self, upstream_task_id: str, downstream_task_id: str) -> EdgeInfoType:
@@ -2632,15 +3051,28 @@ def validate_schedule_and_params(self):
"DAG Schedule must be None, if there are any required params without default values"
)
+ def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
+ """Parses a given link, and verifies if it's a valid URL, or a 'mailto' link.
+ Returns an iterator of invalid (owner, link) pairs.
+ """
+ for owner, link in self.owner_links.items():
+ result = urlsplit(link)
+ if result.scheme == "mailto":
+ # netloc is not existing for 'mailto' link, so we are checking that the path is parsed
+ if not result.path:
+ yield owner, link
+ elif not result.scheme or not result.netloc:
+ yield owner, link
+
class DagTag(Base):
"""A tag name per dag, to allow quick filtering in the DAG view."""
__tablename__ = "dag_tag"
- name = Column(String(100), primary_key=True)
+ name = Column(String(TAG_MAX_LEN), primary_key=True)
dag_id = Column(
- String(ID_LEN),
- ForeignKey('dag.dag_id', name='dag_tag_dag_id_fkey', ondelete='CASCADE'),
+ StringID(),
+ ForeignKey("dag.dag_id", name="dag_tag_dag_id_fkey", ondelete="CASCADE"),
primary_key=True,
)
@@ -2648,6 +3080,33 @@ def __repr__(self):
return self.name
+class DagOwnerAttributes(Base):
+ """
+ Table defining different owner attributes. For example, a link for an owner that will be passed as
+ a hyperlink to the DAGs view.
+ """
+
+ __tablename__ = "dag_owner_attributes"
+ dag_id = Column(
+ StringID(),
+ ForeignKey("dag.dag_id", name="dag.dag_id", ondelete="CASCADE"),
+ nullable=False,
+ primary_key=True,
+ )
+ owner = Column(String(500), primary_key=True, nullable=False)
+ link = Column(String(500), nullable=False)
+
+ def __repr__(self):
+ return f"<DagOwnerAttributes: dag_id={self.dag_id}, owner={self.owner}, link={self.link}>"
+
+ @classmethod
+ def get_all(cls, session) -> dict[str, dict[str, str]]:
+ dag_links: dict = collections.defaultdict(dict)
+ for obj in session.query(cls):
+ dag_links[obj.dag_id].update({obj.owner: obj.link})
+ return dag_links
+
+
class DagModel(Base):
"""Table containing DAG properties"""
@@ -2655,11 +3114,11 @@ class DagModel(Base):
"""
These items are stored in the database for state related information
"""
- dag_id = Column(String(ID_LEN), primary_key=True)
- root_dag_id = Column(String(ID_LEN))
+ dag_id = Column(StringID(), primary_key=True)
+ root_dag_id = Column(StringID())
# A DAG can be paused from the UI / DB
# Set this default value of is_paused based on a configuration value!
- is_paused_at_creation = conf.getboolean('core', 'dags_are_paused_at_creation')
+ is_paused_at_creation = conf.getboolean("core", "dags_are_paused_at_creation")
is_paused = Column(Boolean, default=is_paused_at_creation)
# Whether the DAG is a subdag
is_subdag = Column(Boolean, default=False)
@@ -2681,6 +3140,8 @@ class DagModel(Base):
# packaged DAG, it will point to the subpath of the DAG within the
# associated zip.
fileloc = Column(String(2000))
+ # The base directory used by Dag Processor that parsed this dag.
+ processor_subdir = Column(String(2000), nullable=True)
# String representing the owners
owners = Column(String(2000))
# Description of the dag
@@ -2691,15 +3152,18 @@ class DagModel(Base):
schedule_interval = Column(Interval)
# Timetable/Schedule Interval description
timetable_description = Column(String(1000), nullable=True)
-
# Tags for view filter
- tags = relationship('DagTag', cascade='all, delete, delete-orphan', backref=backref('dag'))
+ tags = relationship("DagTag", cascade="all, delete, delete-orphan", backref=backref("dag"))
+ # Dag owner links for DAGs view
+ dag_owner_links = relationship(
+ "DagOwnerAttributes", cascade="all, delete, delete-orphan", backref=backref("dag")
+ )
max_active_tasks = Column(Integer, nullable=False)
max_active_runs = Column(Integer, nullable=True)
has_task_concurrency_limits = Column(Boolean, nullable=False)
- has_import_errors = Column(Boolean(), default=False)
+ has_import_errors = Column(Boolean(), default=False, server_default="0")
# The logical date of the next dag run.
next_dagrun = Column(UtcDateTime)
@@ -2712,15 +3176,23 @@ class DagModel(Base):
next_dagrun_create_after = Column(UtcDateTime)
__table_args__ = (
- Index('idx_root_dag_id', root_dag_id, unique=False),
- Index('idx_next_dagrun_create_after', next_dagrun_create_after, unique=False),
+ Index("idx_root_dag_id", root_dag_id, unique=False),
+ Index("idx_next_dagrun_create_after", next_dagrun_create_after, unique=False),
)
parent_dag = relationship(
"DagModel", remote_side=[dag_id], primaryjoin=root_dag_id == dag_id, foreign_keys=[root_dag_id]
)
-
- NUM_DAGS_PER_DAGRUN_QUERY = conf.getint('scheduler', 'max_dagruns_to_create_per_loop', fallback=10)
+ schedule_dataset_references = relationship(
+ "DagScheduleDatasetReference",
+ cascade="all, delete, delete-orphan",
+ )
+ schedule_datasets = association_proxy("schedule_dataset_references", "dataset")
+ task_outlet_dataset_references = relationship(
+ "TaskOutletDatasetReference",
+ cascade="all, delete, delete-orphan",
+ )
+ NUM_DAGS_PER_DAGRUN_QUERY = conf.getint("scheduler", "max_dagruns_to_create_per_loop", fallback=10)
def __init__(self, concurrency=None, **kwargs):
super().__init__(**kwargs)
@@ -2728,15 +3200,15 @@ def __init__(self, concurrency=None, **kwargs):
if concurrency:
warnings.warn(
"The 'DagModel.concurrency' parameter is deprecated. Please use 'max_active_tasks'.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
self.max_active_tasks = concurrency
else:
- self.max_active_tasks = conf.getint('core', 'max_active_tasks_per_dag')
+ self.max_active_tasks = conf.getint("core", "max_active_tasks_per_dag")
if self.max_active_runs is None:
- self.max_active_runs = conf.getint('core', 'max_active_runs_per_dag')
+ self.max_active_runs = conf.getint("core", "max_active_runs_per_dag")
if self.has_task_concurrency_limits is None:
# Be safe -- this will be updated later once the DAG is parsed
@@ -2746,7 +3218,7 @@ def __repr__(self):
return f"<DAG: {self.dag_id}>"
@property
- def next_dagrun_data_interval(self) -> Optional[DataInterval]:
+ def next_dagrun_data_interval(self) -> DataInterval | None:
return _get_model_data_interval(
self,
"next_dagrun_data_interval_start",
@@ -2754,7 +3226,7 @@ def next_dagrun_data_interval(self) -> Optional[DataInterval]:
)
@next_dagrun_data_interval.setter
- def next_dagrun_data_interval(self, value: Optional[Tuple[datetime, datetime]]) -> None:
+ def next_dagrun_data_interval(self, value: tuple[datetime, datetime] | None) -> None:
if value is None:
self.next_dagrun_data_interval_start = self.next_dagrun_data_interval_end = None
else:
@@ -2774,28 +3246,19 @@ def get_dagmodel(dag_id, session=NEW_SESSION):
def get_current(cls, dag_id, session=NEW_SESSION):
return session.query(cls).filter(cls.dag_id == dag_id).first()
- @staticmethod
- @provide_session
- def get_all_paused_dag_ids(session: Session = NEW_SESSION) -> Set[str]:
- """Get a set of paused DAG ids"""
- paused_dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_paused == expression.true()).all()
-
- paused_dag_ids = {paused_dag_id for paused_dag_id, in paused_dag_ids}
- return paused_dag_ids
-
@provide_session
def get_last_dagrun(self, session=NEW_SESSION, include_externally_triggered=False):
return get_last_dagrun(
self.dag_id, session=session, include_externally_triggered=include_externally_triggered
)
- def get_is_paused(self, *, session: Optional[Session] = None) -> bool:
+ def get_is_paused(self, *, session: Session | None = None) -> bool:
"""Provide interface compatibility to 'DAG'."""
return self.is_paused
@staticmethod
@provide_session
- def get_paused_dag_ids(dag_ids: List[str], session: Session = NEW_SESSION) -> Set[str]:
+ def get_paused_dag_ids(dag_ids: list[str], session: Session = NEW_SESSION) -> set[str]:
"""
Given a list of dag_ids, get a set of Paused Dag Ids
@@ -2819,14 +3282,14 @@ def get_default_view(self) -> str:
have a value
"""
# This is for backwards-compatibility with old dags that don't have None as default_view
- return self.default_view or conf.get_mandatory_value('webserver', 'dag_default_view').lower()
+ return self.default_view or conf.get_mandatory_value("webserver", "dag_default_view").lower()
@property
def safe_dag_id(self):
- return self.dag_id.replace('.', '__dot__')
+ return self.dag_id.replace(".", "__dot__")
@property
- def relative_fileloc(self) -> Optional[pathlib.Path]:
+ def relative_fileloc(self) -> pathlib.Path | None:
"""File location of the importable dag 'file' relative to the configured DAGs folder."""
if self.fileloc is None:
return None
@@ -2852,13 +3315,13 @@ def set_is_paused(self, is_paused: bool, including_subdags: bool = True, session
if including_subdags:
filter_query.append(DagModel.root_dag_id == self.dag_id)
session.query(DagModel).filter(or_(*filter_query)).update(
- {DagModel.is_paused: is_paused}, synchronize_session='fetch'
+ {DagModel.is_paused: is_paused}, synchronize_session="fetch"
)
session.commit()
@classmethod
@provide_session
- def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=NEW_SESSION):
+ def deactivate_deleted_dags(cls, alive_dag_filelocs: list[str], session=NEW_SESSION):
"""
Set ``is_active=False`` on the DAGs for which the DAG files have been removed.
@@ -2876,36 +3339,74 @@ def deactivate_deleted_dags(cls, alive_dag_filelocs: List[str], session=NEW_SESS
continue
@classmethod
- def dags_needing_dagruns(cls, session: Session):
+ def dags_needing_dagruns(cls, session: Session) -> tuple[Query, dict[str, tuple[datetime, datetime]]]:
"""
Return (and lock) a list of Dag objects that are due to create a new DagRun.
- This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query,
+ This will return a resultset of rows that is row-level-locked with a "SELECT ... FOR UPDATE" query,
you should ensure that any scheduling decisions are made in a single transaction -- as soon as the
transaction is committed it will be unlocked.
"""
- # TODO[HA]: Bake this query, it is run _A lot_
- # We limit so that _one_ scheduler doesn't try to do all the creation
- # of dag runs
+ from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue as DDRQ
+
+ # these dag ids are triggered by datasets, and they are ready to go.
+ dataset_triggered_dag_info = {
+ x.dag_id: (x.first_queued_time, x.last_queued_time)
+ for x in session.query(
+ DagScheduleDatasetReference.dag_id,
+ func.max(DDRQ.created_at).label("last_queued_time"),
+ func.min(DDRQ.created_at).label("first_queued_time"),
+ )
+ .join(DagScheduleDatasetReference.queue_records, isouter=True)
+ .group_by(DagScheduleDatasetReference.dag_id)
+ .having(func.count() == func.sum(case((DDRQ.target_dag_id.is_not(None), 1), else_=0)))
+ .all()
+ }
+ dataset_triggered_dag_ids = set(dataset_triggered_dag_info.keys())
+ if dataset_triggered_dag_ids:
+ exclusion_list = {
+ x.dag_id
+ for x in (
+ session.query(DagModel.dag_id)
+ .join(DagRun.dag_model)
+ .filter(DagRun.state.in_((DagRunState.QUEUED, DagRunState.RUNNING)))
+ .filter(DagModel.dag_id.in_(dataset_triggered_dag_ids))
+ .group_by(DagModel.dag_id)
+ .having(func.count() >= func.max(DagModel.max_active_runs))
+ .all()
+ )
+ }
+ if exclusion_list:
+ dataset_triggered_dag_ids -= exclusion_list
+ dataset_triggered_dag_info = {
+ k: v for k, v in dataset_triggered_dag_info.items() if k not in exclusion_list
+ }
+ # We limit so that _one_ scheduler doesn't try to do all the creation of dag runs
query = (
session.query(cls)
.filter(
cls.is_paused == expression.false(),
cls.is_active == expression.true(),
cls.has_import_errors == expression.false(),
- cls.next_dagrun_create_after <= func.now(),
+ or_(
+ cls.next_dagrun_create_after <= func.now(),
+ cls.dag_id.in_(dataset_triggered_dag_ids),
+ ),
)
.order_by(cls.next_dagrun_create_after)
.limit(cls.NUM_DAGS_PER_DAGRUN_QUERY)
)
- return with_row_locks(query, of=cls, session=session, **skip_locked(session=session))
+ return (
+ with_row_locks(query, of=cls, session=session, **skip_locked(session=session)),
+ dataset_triggered_dag_info,
+ )
def calculate_dagrun_date_fields(
self,
dag: DAG,
- most_recent_dag_run: Union[None, datetime, DataInterval],
+ most_recent_dag_run: None | datetime | DataInterval,
) -> None:
"""
Calculate ``next_dagrun`` and ``next_dagrun_create_after``
@@ -2914,12 +3415,12 @@ def calculate_dagrun_date_fields(
:param most_recent_dag_run: DataInterval (or datetime) of most recent run of this dag, or none
if not yet scheduled.
"""
- most_recent_data_interval: Optional[DataInterval]
+ most_recent_data_interval: DataInterval | None
if isinstance(most_recent_dag_run, datetime):
warnings.warn(
"Passing a datetime to `DagModel.calculate_dagrun_date_fields` is deprecated. "
"Provide a data interval instead.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
most_recent_data_interval = dag.infer_automated_data_interval(most_recent_dag_run)
@@ -2940,8 +3441,49 @@ def calculate_dagrun_date_fields(
self.next_dagrun_create_after,
)
-
-def dag(*dag_args, **dag_kwargs):
+ @provide_session
+ def get_dataset_triggered_next_run_info(self, *, session=NEW_SESSION) -> dict[str, int | str] | None:
+ if self.schedule_interval != "Dataset":
+ return None
+ return get_dataset_triggered_next_run_info([self.dag_id], session=session)[self.dag_id]
+
+
+# NOTE: Please keep the list of arguments in sync with DAG.__init__.
+# Only exception: dag_id here should have a default value, but not in DAG.
+def dag(
+ dag_id: str = "",
+ description: str | None = None,
+ schedule: ScheduleArg = NOTSET,
+ schedule_interval: ScheduleIntervalArg = NOTSET,
+ timetable: Timetable | None = None,
+ start_date: datetime | None = None,
+ end_date: datetime | None = None,
+ full_filepath: str | None = None,
+ template_searchpath: str | Iterable[str] | None = None,
+ template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined,
+ user_defined_macros: dict | None = None,
+ user_defined_filters: dict | None = None,
+ default_args: dict | None = None,
+ concurrency: int | None = None,
+ max_active_tasks: int = conf.getint("core", "max_active_tasks_per_dag"),
+ max_active_runs: int = conf.getint("core", "max_active_runs_per_dag"),
+ dagrun_timeout: timedelta | None = None,
+ sla_miss_callback: SLAMissCallback | None = None,
+ default_view: str = conf.get_mandatory_value("webserver", "dag_default_view").lower(),
+ orientation: str = conf.get_mandatory_value("webserver", "dag_orientation"),
+ catchup: bool = conf.getboolean("scheduler", "catchup_by_default"),
+ on_success_callback: DagStateChangeCallback | None = None,
+ on_failure_callback: DagStateChangeCallback | None = None,
+ doc_md: str | None = None,
+ params: dict | None = None,
+ access_control: dict | None = None,
+ is_paused_upon_creation: bool | None = None,
+ jinja_environment_kwargs: dict | None = None,
+ render_template_as_native_obj: bool = False,
+ tags: list[str] | None = None,
+ owner_links: dict[str, str] | None = None,
+ auto_register: bool = True,
+) -> Callable[[Callable], Callable[..., DAG]]:
"""
Python dag decorator. Wraps a function into an Airflow DAG.
Accepts kwargs for operator kwarg. Can be used to parameterize DAGs.
@@ -2950,26 +3492,51 @@ def dag(*dag_args, **dag_kwargs):
:param dag_kwargs: Kwargs for DAG object.
"""
- def wrapper(f: Callable):
- # Get dag initializer signature and bind it to validate that dag_args, and dag_kwargs are correct
- dag_sig = signature(DAG.__init__)
- dag_bound_args = dag_sig.bind_partial(*dag_args, **dag_kwargs)
-
+ def wrapper(f: Callable) -> Callable[..., DAG]:
@functools.wraps(f)
def factory(*args, **kwargs):
# Generate signature for decorated function and bind the arguments when called
- # we do this to extract parameters so we can annotate them on the DAG object.
+ # we do this to extract parameters, so we can annotate them on the DAG object.
# In addition, this fails if we are missing any args/kwargs with TypeError as expected.
f_sig = signature(f).bind(*args, **kwargs)
# Apply defaults to capture default values if set.
f_sig.apply_defaults()
- # Set function name as dag_id if not set
- dag_id = dag_bound_args.arguments.get('dag_id', f.__name__)
- dag_bound_args.arguments['dag_id'] = dag_id
-
# Initialize DAG with bound arguments
- with DAG(*dag_bound_args.args, **dag_bound_args.kwargs) as dag_obj:
+ with DAG(
+ dag_id or f.__name__,
+ description=description,
+ schedule_interval=schedule_interval,
+ timetable=timetable,
+ start_date=start_date,
+ end_date=end_date,
+ full_filepath=full_filepath,
+ template_searchpath=template_searchpath,
+ template_undefined=template_undefined,
+ user_defined_macros=user_defined_macros,
+ user_defined_filters=user_defined_filters,
+ default_args=default_args,
+ concurrency=concurrency,
+ max_active_tasks=max_active_tasks,
+ max_active_runs=max_active_runs,
+ dagrun_timeout=dagrun_timeout,
+ sla_miss_callback=sla_miss_callback,
+ default_view=default_view,
+ orientation=orientation,
+ catchup=catchup,
+ on_success_callback=on_success_callback,
+ on_failure_callback=on_failure_callback,
+ doc_md=doc_md,
+ params=params,
+ access_control=access_control,
+ is_paused_upon_creation=is_paused_upon_creation,
+ jinja_environment_kwargs=jinja_environment_kwargs,
+ render_template_as_native_obj=render_template_as_native_obj,
+ tags=tags,
+ schedule=schedule,
+ owner_links=owner_links,
+ auto_register=auto_register,
+ ) as dag_obj:
# Set DAG documentation from function documentation.
if f.__doc__:
dag_obj.doc_md = f.__doc__
@@ -2990,13 +3557,15 @@ def factory(*args, **kwargs):
# Return dag object such that it's accessible in Globals.
return dag_obj
+ # Ensure that warnings from inside DAG() are emitted from the caller, not here
+ fixup_decorator_warning_stack(factory)
return factory
return wrapper
STATICA_HACK = True
-globals()['kcah_acitats'[::-1].upper()] = False
+globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
from airflow.models.serialized_dag import SerializedDagModel
@@ -3016,7 +3585,7 @@ class DagContext:
with DAG(
dag_id="example_dag",
default_args=default_args,
- schedule_interval="0 0 * * *",
+ schedule="0 0 * * *",
dagrun_timeout=timedelta(minutes=60),
) as dag:
...
@@ -3026,24 +3595,91 @@ class DagContext:
"""
- _context_managed_dag: Optional[DAG] = None
- _previous_context_managed_dags: List[DAG] = []
+ _context_managed_dags: Deque[DAG] = deque()
+ autoregistered_dags: set[tuple[DAG, ModuleType]] = set()
+ current_autoregister_module_name: str | None = None
@classmethod
def push_context_managed_dag(cls, dag: DAG):
- if cls._context_managed_dag:
- cls._previous_context_managed_dags.append(cls._context_managed_dag)
- cls._context_managed_dag = dag
+ cls._context_managed_dags.appendleft(dag)
@classmethod
- def pop_context_managed_dag(cls) -> Optional[DAG]:
- old_dag = cls._context_managed_dag
- if cls._previous_context_managed_dags:
- cls._context_managed_dag = cls._previous_context_managed_dags.pop()
- else:
- cls._context_managed_dag = None
- return old_dag
+ def pop_context_managed_dag(cls) -> DAG | None:
+ dag = cls._context_managed_dags.popleft()
+
+ # In a few cases around serialization we explicitly push None in to the stack
+ if cls.current_autoregister_module_name is not None and dag and dag.auto_register:
+ mod = sys.modules[cls.current_autoregister_module_name]
+ cls.autoregistered_dags.add((dag, mod))
+
+ return dag
@classmethod
- def get_current_dag(cls) -> Optional[DAG]:
- return cls._context_managed_dag
+ def get_current_dag(cls) -> DAG | None:
+ try:
+ return cls._context_managed_dags[0]
+ except IndexError:
+ return None
+
+
+def _run_task(ti: TaskInstance, session):
+ """
+ Run a single task instance and push the result to XCom for downstream tasks. Bypasses a lot of
+ extra steps used in `task.run` to keep local runs as fast as possible.
+ This function is only meant to be a helper for the `dag.test` function.
+
+ :param ti: TaskInstance to run
+ :param session: sqlalchemy session
+ """
+ log.info("*****************************************************")
+ if ti.map_index > 0:
+ log.info("Running task %s index %d", ti.task_id, ti.map_index)
+ else:
+ log.info("Running task %s", ti.task_id)
+ try:
+ ti._run_raw_task(session=session)
+ session.flush()
+ log.info("%s ran successfully!", ti.task_id)
+ except AirflowSkipException:
+ log.info("Task Skipped, continuing")
+ log.info("*****************************************************")
+
+
+def _get_or_create_dagrun(
+ dag: DAG,
+ conf: dict[Any, Any] | None,
+ start_date: datetime,
+ execution_date: datetime,
+ run_id: str,
+ session: Session,
+) -> DagRun:
+ """
+ Create a DAGRun, but only after clearing the previous instance of said dagrun to prevent collisions.
+ This function is only meant for the `dag.test` function as a helper function.
+ :param dag: Dag to be used to find dagrun
+ :param conf: configuration to pass to newly created dagrun
+ :param start_date: start date of new dagrun, defaults to execution_date
+ :param execution_date: execution_date for finding the dagrun
+ :param run_id: run_id to pass to new dagrun
+ :param session: sqlalchemy session
+ :return:
+ """
+ log.info("dagrun id: %s", dag.dag_id)
+ dr: DagRun = (
+ session.query(DagRun)
+ .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
+ .first()
+ )
+ if dr:
+ session.delete(dr)
+ session.commit()
+ dr = dag.create_dagrun(
+ state=DagRunState.RUNNING,
+ execution_date=execution_date,
+ run_id=run_id,
+ start_date=start_date or execution_date,
+ session=session,
+ conf=conf, # type: ignore
+ )
+ log.info("created dagrun %s", dr)
+ return dr
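Reviewer note on the `DagContext` change in the last `dag.py` hunk: the pair of attributes (`_context_managed_dag` plus `_previous_context_managed_dags`) is replaced by a single deque whose left end is the innermost active DAG. The sketch below is only an illustration of that stack discipline, not the Airflow implementation. `ContextStack`, `push`, `pop`, and `current` are hypothetical names, plain strings stand in for `DAG` objects, and the auto-registration step is left out.

```python
from collections import deque
from typing import Deque, Optional


class ContextStack:
    """Hypothetical stand-in for DagContext; items are plain strings, not DAG objects."""

    # Class-level deque shared by all callers, as in the diff.
    _items: Deque[str] = deque()

    @classmethod
    def push(cls, item: str) -> None:
        # The newest item goes on the left, so index 0 is always the innermost context.
        cls._items.appendleft(item)

    @classmethod
    def pop(cls) -> str:
        # Remove and return the innermost item; raises IndexError if the stack is empty.
        return cls._items.popleft()

    @classmethod
    def current(cls) -> Optional[str]:
        # Same shape as get_current_dag in the diff: an empty stack means no active context.
        try:
            return cls._items[0]
        except IndexError:
            return None


# Nested contexts: "inner" shadows "outer" until it is popped.
ContextStack.push("outer")
ContextStack.push("inner")
innermost = ContextStack.current()
popped = ContextStack.pop()
after_pop = ContextStack.current()
ContextStack.pop()
empty = ContextStack.current()
```

With one deque there is no separate "previous" list to keep in sync, and the empty case is just an `IndexError` on index 0, which is how `get_current_dag` in the diff returns `None`.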
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 3673ce095ea16..f78125ebd9d46 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import hashlib
import importlib
@@ -27,7 +28,7 @@
import warnings
import zipfile
from datetime import datetime, timedelta
-from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, NamedTuple
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
@@ -39,8 +40,10 @@
AirflowClusterPolicyViolation,
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
+ AirflowDagInconsistent,
AirflowTimetableInvalid,
ParamValidationError,
+ RemovedInAirflow3Warning,
)
from airflow.stats import Stats
from airflow.utils import timezone
@@ -51,6 +54,7 @@
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
from airflow.utils.session import provide_session
from airflow.utils.timeout import timeout
+from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
import pathlib
@@ -79,8 +83,6 @@ class DagBag(LoggingMixin):
:param dag_folder: the folder to scan to find DAGs
:param include_examples: whether to include the examples that ship
with airflow or not
- :param include_smart_sensor: whether to include the smart sensor native
- DAGs that create the smart sensor operators for whole cluster
:param read_dags_from_db: Read DAGs from DB if ``True`` is passed.
If ``False`` DAGs are read from python files.
:param load_op_links: Should the extra operator link be loaded via plugins when
@@ -90,49 +92,58 @@ class DagBag(LoggingMixin):
def __init__(
self,
- dag_folder: Union[str, "pathlib.Path", None] = None,
- include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
- include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
- safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
+ dag_folder: str | pathlib.Path | None = None,
+ include_examples: bool | ArgNotSet = NOTSET,
+ safe_mode: bool | ArgNotSet = NOTSET,
read_dags_from_db: bool = False,
- store_serialized_dags: Optional[bool] = None,
+ store_serialized_dags: bool | None = None,
load_op_links: bool = True,
+ collect_dags: bool = True,
):
# Avoid circular import
from airflow.models.dag import DAG
super().__init__()
+ include_examples = (
+ include_examples
+ if isinstance(include_examples, bool)
+ else conf.getboolean("core", "LOAD_EXAMPLES")
+ )
+ safe_mode = (
+ safe_mode if isinstance(safe_mode, bool) else conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE")
+ )
+
if store_serialized_dags:
warnings.warn(
"The store_serialized_dags parameter has been deprecated. "
"You should pass the read_dags_from_db parameter.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
read_dags_from_db = store_serialized_dags
dag_folder = dag_folder or settings.DAGS_FOLDER
self.dag_folder = dag_folder
- self.dags: Dict[str, DAG] = {}
+ self.dags: dict[str, DAG] = {}
# the file's last modified timestamp when we last read it
- self.file_last_changed: Dict[str, datetime] = {}
- self.import_errors: Dict[str, str] = {}
+ self.file_last_changed: dict[str, datetime] = {}
+ self.import_errors: dict[str, str] = {}
self.has_logged = False
self.read_dags_from_db = read_dags_from_db
# Only used by read_dags_from_db=True
- self.dags_last_fetched: Dict[str, datetime] = {}
+ self.dags_last_fetched: dict[str, datetime] = {}
# Only used by SchedulerJob to compare the dag_hash to identify change in DAGs
- self.dags_hash: Dict[str, str] = {}
-
- self.dagbag_import_error_tracebacks = conf.getboolean('core', 'dagbag_import_error_tracebacks')
- self.dagbag_import_error_traceback_depth = conf.getint('core', 'dagbag_import_error_traceback_depth')
- self.collect_dags(
- dag_folder=dag_folder,
- include_examples=include_examples,
- include_smart_sensor=include_smart_sensor,
- safe_mode=safe_mode,
- )
+ self.dags_hash: dict[str, str] = {}
+
+ self.dagbag_import_error_tracebacks = conf.getboolean("core", "dagbag_import_error_tracebacks")
+ self.dagbag_import_error_traceback_depth = conf.getint("core", "dagbag_import_error_traceback_depth")
+ if collect_dags:
+ self.collect_dags(
+ dag_folder=dag_folder,
+ include_examples=include_examples,
+ safe_mode=safe_mode,
+ )
# Should the extra operator link be loaded via plugins?
# This flag is set to False in Scheduler so that Extra Operator links are not loaded
self.load_op_links = load_op_links
@@ -143,19 +154,20 @@ def size(self) -> int:
@property
def store_serialized_dags(self) -> bool:
- """Whether or not to read dags from DB"""
+ """Whether to read dags from DB"""
warnings.warn(
"The store_serialized_dags property has been deprecated. Use read_dags_from_db instead.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.read_dags_from_db
@property
- def dag_ids(self) -> List[str]:
+ def dag_ids(self) -> list[str]:
"""
+ Get DAG ids.
+
:return: a list of DAG IDs in this bag
- :rtype: List[unicode]
"""
return list(self.dags.keys())
@@ -164,7 +176,7 @@ def get_dag(self, dag_id, session: Session = None):
"""
Gets the DAG out of the dictionary, and refreshes it if expired
- :param dag_id: DAG Id
+ :param dag_id: DAG ID
"""
# Avoid circular import
from airflow.models.dag import DagModel
@@ -262,9 +274,12 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
Given a path to a python module or zip file, this method imports
the module and looks for dag objects within it.
"""
+ from airflow.models.dag import DagContext
+
# if the source file no longer exists in the DB or in the filesystem,
# return an empty list
# todo: raise exception?
+
if filepath is None or not os.path.isfile(filepath):
return []
@@ -282,6 +297,9 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
self.log.exception(e)
return []
+ # Ensure we don't pick up anything else we didn't mean to
+ DagContext.autoregistered_dags.clear()
+
if filepath.endswith(".py") or not zipfile.is_zipfile(filepath):
mods = self._load_modules_from_file(filepath, safe_mode)
else:
@@ -293,6 +311,8 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
return found_dags
def _load_modules_from_file(self, filepath, safe_mode):
+ from airflow.models.dag import DagContext
+
if not might_contain_dag(filepath, safe_mode):
# Don't want to spam user with skip messages
if not self.has_logged:
@@ -302,12 +322,14 @@ def _load_modules_from_file(self, filepath, safe_mode):
self.log.debug("Importing %s", filepath)
org_mod_name, _ = os.path.splitext(os.path.split(filepath)[-1])
- path_hash = hashlib.sha1(filepath.encode('utf-8')).hexdigest()
- mod_name = f'unusual_prefix_{path_hash}_{org_mod_name}'
+ path_hash = hashlib.sha1(filepath.encode("utf-8")).hexdigest()
+ mod_name = f"unusual_prefix_{path_hash}_{org_mod_name}"
if mod_name in sys.modules:
del sys.modules[mod_name]
+ DagContext.current_autoregister_module_name = mod_name
+
def parse(mod_name, filepath):
try:
loader = importlib.machinery.SourceFileLoader(mod_name, filepath)
@@ -317,6 +339,7 @@ def parse(mod_name, filepath):
loader.exec_module(new_module)
return [new_module]
except Exception as e:
+ DagContext.autoregistered_dags.clear()
self.log.exception("Failed to import: %s", filepath)
if self.dagbag_import_error_tracebacks:
self.import_errors[filepath] = traceback.format_exc(
@@ -330,7 +353,7 @@ def parse(mod_name, filepath):
if not isinstance(dagbag_import_timeout, (int, float)):
raise TypeError(
- f'Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float'
+ f"Value ({dagbag_import_timeout}) from get_dagbag_import_timeout must be int or float"
)
if dagbag_import_timeout <= 0: # no parsing timeout
@@ -346,6 +369,8 @@ def parse(mod_name, filepath):
return parse(mod_name, filepath)
def _load_modules_from_zip(self, filepath, safe_mode):
+ from airflow.models.dag import DagContext
+
mods = []
with zipfile.ZipFile(filepath) as current_zip_file:
for zip_info in current_zip_file.infolist():
@@ -356,7 +381,7 @@ def _load_modules_from_zip(self, filepath, safe_mode):
if head:
continue
- if mod_name == '__init__':
+ if mod_name == "__init__":
self.log.warning("Found __init__.%s at root of %s", ext, filepath)
self.log.debug("Reading %s from %s", zip_info.filename, filepath)
@@ -374,11 +399,13 @@ def _load_modules_from_zip(self, filepath, safe_mode):
if mod_name in sys.modules:
del sys.modules[mod_name]
+ DagContext.current_autoregister_module_name = mod_name
try:
sys.path.insert(0, filepath)
current_module = importlib.import_module(mod_name)
mods.append(current_module)
except Exception as e:
+ DagContext.autoregistered_dags.clear()
fileloc = os.path.join(filepath, zip_info.filename)
self.log.exception("Failed to import: %s", fileloc)
if self.dagbag_import_error_tracebacks:
@@ -393,34 +420,39 @@ def _load_modules_from_zip(self, filepath, safe_mode):
return mods
def _process_modules(self, filepath, mods, file_last_changed_on_disk):
- from airflow.models.dag import DAG # Avoid circular import
+ from airflow.models.dag import DAG, DagContext # Avoid circular import
+
+ top_level_dags = {(o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG)}
- top_level_dags = ((o, m) for m in mods for o in m.__dict__.values() if isinstance(o, DAG))
+ top_level_dags.update(DagContext.autoregistered_dags)
+
+ DagContext.current_autoregister_module_name = None
+ DagContext.autoregistered_dags.clear()
found_dags = []
for (dag, mod) in top_level_dags:
dag.fileloc = mod.__file__
try:
- dag.timetable.validate()
- # validate dag params
- dag.params.validate()
+ dag.validate()
self.bag_dag(dag=dag, root_dag=dag)
- found_dags.append(dag)
- found_dags += dag.subdags
except AirflowTimetableInvalid as exception:
self.log.exception("Failed to bag_dag: %s", dag.fileloc)
self.import_errors[dag.fileloc] = f"Invalid timetable expression: {exception}"
self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
except (
+ AirflowClusterPolicyViolation,
AirflowDagCycleException,
AirflowDagDuplicatedIdException,
- AirflowClusterPolicyViolation,
+ AirflowDagInconsistent,
ParamValidationError,
) as exception:
self.log.exception("Failed to bag_dag: %s", dag.fileloc)
self.import_errors[dag.fileloc] = str(exception)
self.file_last_changed[dag.fileloc] = file_last_changed_on_disk
+ else:
+ found_dags.append(dag)
+ found_dags += dag.subdags
return found_dags
def bag_dag(self, dag, root_dag):
@@ -468,10 +500,10 @@ def _bag_dag(self, *, dag, root_dag, recursive):
existing=self.dags[dag.dag_id].fileloc,
)
self.dags[dag.dag_id] = dag
- self.log.debug('Loaded DAG %s', dag)
+ self.log.debug("Loaded DAG %s", dag)
except (AirflowDagCycleException, AirflowDagDuplicatedIdException):
# There was an error in bagging the dag. Remove it from the list of dags
- self.log.exception('Exception bagging dag: %s', dag.dag_id)
+ self.log.exception("Exception bagging dag: %s", dag.dag_id)
# Only necessary at the root level since DAG.subdags automatically
# performs DFS to search through all subdags
if recursive:
@@ -482,11 +514,10 @@ def _bag_dag(self, *, dag, root_dag, recursive):
def collect_dags(
self,
- dag_folder: Union[str, "pathlib.Path", None] = None,
+ dag_folder: str | pathlib.Path | None = None,
only_if_updated: bool = True,
- include_examples: bool = conf.getboolean('core', 'LOAD_EXAMPLES'),
- include_smart_sensor: bool = conf.getboolean('smart_sensor', 'USE_SMART_SENSOR'),
- safe_mode: bool = conf.getboolean('core', 'DAG_DISCOVERY_SAFE_MODE'),
+ include_examples: bool = conf.getboolean("core", "LOAD_EXAMPLES"),
+ safe_mode: bool = conf.getboolean("core", "DAG_DISCOVERY_SAFE_MODE"),
):
"""
Given a file path or a folder, this method looks for python modules,
@@ -515,7 +546,6 @@ def collect_dags(
dag_folder,
safe_mode=safe_mode,
include_examples=include_examples,
- include_smart_sensor=include_smart_sensor,
):
try:
file_parse_start_dttm = timezone.utcnow()
@@ -524,7 +554,7 @@ def collect_dags(
file_parse_end_dttm = timezone.utcnow()
stats.append(
FileLoadStat(
- file=filepath.replace(settings.DAGS_FOLDER, ''),
+ file=filepath.replace(settings.DAGS_FOLDER, ""),
duration=file_parse_end_dttm - file_parse_start_dttm,
dag_num=len(found_dags),
task_num=sum(len(dag.tasks) for dag in found_dags),
@@ -540,7 +570,7 @@ def collect_dags_from_db(self):
"""Collects DAGs from database."""
from airflow.models.serialized_dag import SerializedDagModel
- with Stats.timer('collect_db_dags'):
+ with Stats.timer("collect_db_dags"):
self.log.info("Filling up the DagBag from database")
# The dagbag contains all rows in serialized_dag table. Deleted DAGs are deleted
@@ -579,7 +609,7 @@ def dagbag_report(self):
return report
@provide_session
- def sync_to_db(self, session: Session = None):
+ def sync_to_db(self, processor_subdir: str | None = None, session: Session = None):
"""Save attributes about list of DAG to the DB."""
# To avoid circular import - airflow.models.dagbag -> airflow.models.dag -> airflow.models.dagbag
from airflow.models.dag import DAG
@@ -626,7 +656,9 @@ def _serialize_dag_capturing_errors(dag, session):
for dag in self.dags.values():
serialize_errors.extend(_serialize_dag_capturing_errors(dag, session))
- DAG.bulk_write_to_db(self.dags.values(), session=session)
+ DAG.bulk_write_to_db(
+ self.dags.values(), processor_subdir=processor_subdir, session=session
+ )
except OperationalError:
session.rollback()
raise
@@ -640,6 +672,8 @@ def _sync_perm_for_dag(self, dag, session: Session = None):
from airflow.security.permissions import DAG_ACTIONS, resource_name_for_dag
from airflow.www.fab_security.sqla.models import Action, Permission, Resource
+ root_dag_id = dag.parent_dag.dag_id if dag.parent_dag else dag.dag_id
+
def needs_perms(dag_id: str) -> bool:
dag_resource_name = resource_name_for_dag(dag_id)
for permission_name in DAG_ACTIONS:
@@ -654,9 +688,9 @@ def needs_perms(dag_id: str) -> bool:
return True
return False
- if dag.access_control or needs_perms(dag.dag_id):
- self.log.debug("Syncing DAG permissions: %s to the DB", dag.dag_id)
+ if dag.access_control or needs_perms(root_dag_id):
+ self.log.debug("Syncing DAG permissions: %s to the DB", root_dag_id)
from airflow.www.security import ApplessAirflowSecurityManager
security_manager = ApplessAirflowSecurityManager(session=session)
- security_manager.sync_perm_for_dag(dag.dag_id, dag.access_control)
+ security_manager.sync_perm_for_dag(root_dag_id, dag.access_control)
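
In the `DagBag.__init__` hunk above, the `include_examples` and `safe_mode` defaults change from `conf.getboolean(...)` calls to the `NOTSET` sentinel, and the configuration is read inside the function body instead. The sketch below illustrates the general difference between the two approaches. It assumes a hypothetical `SETTINGS` dictionary in place of Airflow's `conf` object and defines its own `ArgNotSet` class rather than importing the one from `airflow.utils.types`.

```python
from __future__ import annotations

# Hypothetical stand-in for a configuration source; not Airflow's `conf` object.
SETTINGS = {"load_examples": True}


class ArgNotSet:
    """Sentinel type meaning "the caller did not pass this argument"."""


NOTSET = ArgNotSet()


def resolve_include_examples(include_examples: bool | ArgNotSet = NOTSET) -> bool:
    # An explicit bool wins; anything else falls back to the setting as it is *now*.
    if isinstance(include_examples, bool):
        return include_examples
    return SETTINGS["load_examples"]


default_before = resolve_include_examples()
SETTINGS["load_examples"] = False  # change configuration after the function was defined
default_after = resolve_include_examples()
explicit = resolve_include_examples(True)
```

A default written as a function call in the signature is evaluated once, when the `def` statement runs, so later configuration changes would not be seen. Note that the `collect_dags` signature in the same diff still uses `conf.getboolean(...)` defaults.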
diff --git a/airflow/models/dagcode.py b/airflow/models/dagcode.py
index 7322ba92fb76a..47b9588cc8d00 100644
--- a/airflow/models/dagcode.py
+++ b/airflow/models/dagcode.py
@@ -14,13 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
import os
import struct
from datetime import datetime
-from typing import Iterable, List, Optional
+from typing import Iterable
from sqlalchemy import BigInteger, Column, String, Text
+from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.sql.expression import literal
from airflow.exceptions import AirflowException, DagCodeNotFound
@@ -41,15 +44,15 @@ class DagCode(Base):
For details on dag serialization see SerializedDagModel
"""
- __tablename__ = 'dag_code'
+ __tablename__ = "dag_code"
fileloc_hash = Column(BigInteger, nullable=False, primary_key=True, autoincrement=False)
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
last_updated = Column(UtcDateTime, nullable=False)
- source_code = Column(Text, nullable=False)
+ source_code = Column(Text().with_variant(MEDIUMTEXT(), "mysql"), nullable=False)
- def __init__(self, full_filepath: str, source_code: Optional[str] = None):
+ def __init__(self, full_filepath: str, source_code: str | None = None):
self.fileloc = full_filepath
self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
self.last_updated = timezone.utcnow()
@@ -122,7 +125,7 @@ def bulk_sync_to_db(cls, filelocs: Iterable[str], session=None):
@classmethod
@provide_session
- def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None):
+ def remove_deleted_code(cls, alive_dag_filelocs: list[str], session=None):
"""Deletes code not included in alive_dag_filelocs.
:param alive_dag_filelocs: file paths of alive DAGs
@@ -134,7 +137,7 @@ def remove_deleted_code(cls, alive_dag_filelocs: List[str], session=None):
session.query(cls).filter(
cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs)
- ).delete(synchronize_session='fetch')
+ ).delete(synchronize_session="fetch")
@classmethod
@provide_session
@@ -166,7 +169,7 @@ def code(cls, fileloc) -> str:
@staticmethod
def _get_code_from_file(fileloc):
- with open_maybe_zipped(fileloc, 'r') as f:
+ with open_maybe_zipped(fileloc, "r") as f:
code = f.read()
return code
@@ -192,4 +195,4 @@ def dag_fileloc_hash(full_filepath: str) -> int:
import hashlib
# Only 7 bytes because MySQL BigInteger can hold only 8 bytes (signed).
- return struct.unpack('>Q', hashlib.sha1(full_filepath.encode('utf-8')).digest()[-8:])[0] >> 8
+ return struct.unpack(">Q", hashlib.sha1(full_filepath.encode("utf-8")).digest()[-8:])[0] >> 8
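
The `dag_fileloc_hash` expression above packs several steps into one line. A step-by-step sketch of the same computation, using only the standard library, may make the "only 7 bytes" comment easier to follow. The function name `fileloc_hash` and the example path are chosen for this illustration.

```python
import hashlib
import struct


def fileloc_hash(full_filepath: str) -> int:
    digest = hashlib.sha1(full_filepath.encode("utf-8")).digest()
    # Interpret the last 8 bytes as a big-endian unsigned 64-bit integer ...
    (value,) = struct.unpack(">Q", digest[-8:])
    # ... then drop the low 8 bits, leaving a 56-bit (7-byte) non-negative value.
    return value >> 8


example = fileloc_hash("/opt/airflow/dags/example.py")
```

Because the result is always below `2**56`, it fits in a signed 64-bit column regardless of the input path.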
diff --git a/airflow/models/dagparam.py b/airflow/models/dagparam.py
index 83a2f2c05532b..f20bd078c80a6 100644
--- a/airflow/models/dagparam.py
+++ b/airflow/models/dagparam.py
@@ -14,15 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module is deprecated. Please use :mod:`airflow.models.param`."""
+from __future__ import annotations
import warnings
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.param import DagParam # noqa
warnings.warn(
"This module is deprecated. Please use `airflow.models.param`.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
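
Several hunks in this diff replace `DeprecationWarning` with `RemovedInAirflow3Warning`. The sketch below shows the general pattern of a project-specific warning category, assuming (as the name suggests) that such a category subclasses `DeprecationWarning`. `RemovedInNextMajorWarning` and `old_api` are hypothetical names and not Airflow API.

```python
import warnings


class RemovedInNextMajorWarning(DeprecationWarning):
    """Hypothetical project-specific deprecation category."""


def old_api() -> int:
    warnings.warn(
        "old_api is deprecated. Use new_api instead.",
        RemovedInNextMajorWarning,
        stacklevel=2,  # attribute the warning to the caller, not to old_api itself
    )
    return 42


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_api()
```

Existing filters on `DeprecationWarning` still match the subclass, while users can also filter on the dedicated category alone.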
diff --git a/airflow/models/dagpickle.py b/airflow/models/dagpickle.py
index aa56ce3e5884b..caa319e9840f4 100644
--- a/airflow/models/dagpickle.py
+++ b/airflow/models/dagpickle.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import dill
-from sqlalchemy import Column, Integer, PickleType, Text
+from sqlalchemy import BigInteger, Column, Integer, PickleType
from airflow.models.base import Base
from airflow.utils import timezone
@@ -39,13 +40,13 @@ class DagPickle(Base):
id = Column(Integer, primary_key=True)
pickle = Column(PickleType(pickler=dill))
created_dttm = Column(UtcDateTime, default=timezone.utcnow)
- pickle_hash = Column(Text)
+ pickle_hash = Column(BigInteger)
__tablename__ = "dag_pickle"
def __init__(self, dag):
self.dag_id = dag.dag_id
- if hasattr(dag, 'template_env'):
+ if hasattr(dag, "template_env"):
dag.template_env = None
self.pickle_hash = hash(dag)
self.pickle = dag
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index eeec4d5b099b2..2c736c4c2efe5 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -15,34 +15,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import itertools
import os
import warnings
from collections import defaultdict
from datetime import datetime
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Generator,
- Iterable,
- List,
- NamedTuple,
- Optional,
- Sequence,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, NamedTuple, Sequence, TypeVar, overload
from sqlalchemy import (
Boolean,
Column,
ForeignKey,
+ ForeignKeyConstraint,
Index,
Integer,
PickleType,
+ PrimaryKeyConstraint,
String,
+ Text,
UniqueConstraint,
and_,
func,
@@ -50,6 +42,7 @@
text,
)
from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import joinedload, relationship, synonym
from sqlalchemy.orm.session import Session
@@ -58,14 +51,17 @@
from airflow import settings
from airflow.callbacks.callback_requests import DagCallbackRequest
from airflow.configuration import conf as airflow_conf
-from airflow.exceptions import AirflowException, TaskNotFound
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
-from airflow.models.mappedoperator import MappedOperator
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskNotFound
+from airflow.listeners.listener import get_listener_manager
+from airflow.models.abstractoperator import NotMapped
+from airflow.models.base import Base, StringID
+from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI
from airflow.models.tasklog import LogTemplate
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
+from airflow.typing_compat import Literal
from airflow.utils import timezone
from airflow.utils.helpers import is_container
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -78,15 +74,28 @@
from airflow.models.dag import DAG
from airflow.models.operator import Operator
+ CreatedTasks = TypeVar("CreatedTasks", Iterator["dict[str, Any]"], Iterator[TI])
+ TaskCreator = Callable[[Operator, Iterable[int]], CreatedTasks]
+
class TISchedulingDecision(NamedTuple):
"""Type of return for DagRun.task_instance_scheduling_decisions"""
- tis: List[TI]
- schedulable_tis: List[TI]
+ tis: list[TI]
+ schedulable_tis: list[TI]
changed_tis: bool
- unfinished_tis: List[TI]
- finished_tis: List[TI]
+ unfinished_tis: list[TI]
+ finished_tis: list[TI]
+
+
+def _creator_note(val):
+ """Custom creator for the ``note`` association proxy."""
+ if isinstance(val, str):
+ return DagRunNote(content=val)
+ elif isinstance(val, dict):
+ return DagRunNote(**val)
+ else:
+ return DagRunNote(*val)
class DagRun(Base, LoggingMixin):
@@ -98,13 +107,13 @@ class DagRun(Base, LoggingMixin):
__tablename__ = "dag_run"
id = Column(Integer, primary_key=True)
- dag_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
+ dag_id = Column(StringID(), nullable=False)
queued_at = Column(UtcDateTime)
execution_date = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
- _state = Column('state', String(50), default=State.QUEUED)
- run_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False)
+ _state = Column("state", String(50), default=State.QUEUED)
+ run_id = Column(StringID(), nullable=False)
creating_job_id = Column(Integer)
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
@@ -123,23 +132,24 @@ class DagRun(Base, LoggingMixin):
ForeignKey("log_template.id", name="task_instance_log_template_id_fkey", ondelete="NO ACTION"),
default=select([func.max(LogTemplate.__table__.c.id)]),
)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
# Remove this `if` after upgrading Sphinx-AutoAPI
if not TYPE_CHECKING and "BUILDING_AIRFLOW_DOCS" in os.environ:
- dag: "Optional[DAG]"
+ dag: DAG | None
else:
- dag: "Optional[DAG]" = None
+ dag: DAG | None = None
__table_args__ = (
- Index('dag_id_state', dag_id, _state),
- UniqueConstraint('dag_id', 'execution_date', name='dag_run_dag_id_execution_date_key'),
- UniqueConstraint('dag_id', 'run_id', name='dag_run_dag_id_run_id_key'),
- Index('idx_last_scheduling_decision', last_scheduling_decision),
- Index('idx_dag_run_dag_id', dag_id),
+ Index("dag_id_state", dag_id, _state),
+ UniqueConstraint("dag_id", "execution_date", name="dag_run_dag_id_execution_date_key"),
+ UniqueConstraint("dag_id", "run_id", name="dag_run_dag_id_run_id_key"),
+ Index("idx_last_scheduling_decision", last_scheduling_decision),
+ Index("idx_dag_run_dag_id", dag_id),
Index(
- 'idx_dag_run_running_dags',
- 'state',
- 'dag_id',
+ "idx_dag_run_running_dags",
+ "state",
+ "dag_id",
postgresql_where=text("state='running'"),
mssql_where=text("state='running'"),
sqlite_where=text("state='running'"),
@@ -147,9 +157,9 @@ class DagRun(Base, LoggingMixin):
# since mysql lacks filtered/partial indices, this creates a
# duplicate index on mysql. Not the end of the world
Index(
- 'idx_dag_run_queued_dags',
- 'state',
- 'dag_id',
+ "idx_dag_run_queued_dags",
+ "state",
+ "dag_id",
postgresql_where=text("state='queued'"),
mssql_where=text("state='queued'"),
sqlite_where=text("state='queued'"),
@@ -157,29 +167,37 @@ class DagRun(Base, LoggingMixin):
)
task_instances = relationship(
- TI, back_populates="dag_run", cascade='save-update, merge, delete, delete-orphan'
+ TI, back_populates="dag_run", cascade="save-update, merge, delete, delete-orphan"
+ )
+ dag_model = relationship(
+ "DagModel",
+ primaryjoin="foreign(DagRun.dag_id) == DagModel.dag_id",
+ uselist=False,
+ viewonly=True,
)
+ dag_run_note = relationship("DagRunNote", back_populates="dag_run", uselist=False)
+ note = association_proxy("dag_run_note", "content", creator=_creator_note)
DEFAULT_DAGRUNS_TO_EXAMINE = airflow_conf.getint(
- 'scheduler',
- 'max_dagruns_per_loop_to_schedule',
+ "scheduler",
+ "max_dagruns_per_loop_to_schedule",
fallback=20,
)
def __init__(
self,
- dag_id: Optional[str] = None,
- run_id: Optional[str] = None,
- queued_at: Union[datetime, None, ArgNotSet] = NOTSET,
- execution_date: Optional[datetime] = None,
- start_date: Optional[datetime] = None,
- external_trigger: Optional[bool] = None,
- conf: Optional[Any] = None,
- state: Optional[DagRunState] = None,
- run_type: Optional[str] = None,
- dag_hash: Optional[str] = None,
- creating_job_id: Optional[int] = None,
- data_interval: Optional[Tuple[datetime, datetime]] = None,
+ dag_id: str | None = None,
+ run_id: str | None = None,
+ queued_at: datetime | None | ArgNotSet = NOTSET,
+ execution_date: datetime | None = None,
+ start_date: datetime | None = None,
+ external_trigger: bool | None = None,
+ conf: Any | None = None,
+ state: DagRunState | None = None,
+ run_type: str | None = None,
+ dag_hash: str | None = None,
+ creating_job_id: int | None = None,
+ data_interval: tuple[datetime, datetime] | None = None,
):
if data_interval is None:
# Legacy: Only happen for runs created prior to Airflow 2.2.
@@ -206,11 +224,14 @@ def __init__(
def __repr__(self):
return (
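- '<DagRun {dag_id} @ {execution_date}: {run_id}, externally triggered: {external_trigger}>'
+ "<DagRun {dag_id} @ {execution_date}: {run_id}, state:{state}, "
+ "queued_at: {queued_at}. externally triggered: {external_trigger}>"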
).format(
dag_id=self.dag_id,
execution_date=self.execution_date,
run_id=self.run_id,
+ state=self.state,
+ queued_at=self.queued_at,
external_trigger=self.external_trigger,
)
@@ -232,7 +253,7 @@ def set_state(self, state: DagRunState):
@declared_attr
def state(self):
- return synonym('_state', descriptor=property(self.get_state, self.set_state))
+ return synonym("_state", descriptor=property(self.get_state, self.set_state))
@provide_session
def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
@@ -247,9 +268,9 @@ def refresh_from_db(self, session: Session = NEW_SESSION) -> None:
@classmethod
@provide_session
- def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> Dict[str, int]:
+ def active_runs_of_dags(cls, dag_ids=None, only_running=False, session=None) -> dict[str, int]:
"""Get the number of active dag runs for each dag."""
- query = session.query(cls.dag_id, func.count('*'))
+ query = session.query(cls.dag_id, func.count("*"))
if dag_ids is not None:
# 'set' called to avoid duplicate dag_ids, but converted back to 'list'
# because SQLAlchemy doesn't accept a set here.
@@ -266,8 +287,8 @@ def next_dagruns_to_examine(
cls,
state: DagRunState,
session: Session,
- max_number: Optional[int] = None,
- ):
+ max_number: int | None = None,
+ ) -> list[DagRun]:
"""
Return the next DagRuns that the scheduler should attempt to schedule.
@@ -275,7 +296,6 @@ def next_dagruns_to_examine(
query, you should ensure that any scheduling decisions are made in a single transaction -- as soon as
the transaction is committed it will be unlocked.
- :rtype: list[airflow.models.DagRun]
"""
from airflow.models.dag import DagModel
@@ -285,6 +305,7 @@ def next_dagruns_to_examine(
# TODO: Bake this query, it is run _A lot_
query = (
session.query(cls)
+ .with_hint(cls, "USE INDEX (idx_dag_run_running_dags)", dialect_name="mysql")
.filter(cls.state == state, cls.run_type != DagRunType.BACKFILL_JOB)
.join(DagModel, DagModel.dag_id == cls.dag_id)
.filter(DagModel.is_paused == false(), DagModel.is_active == true())
@@ -293,7 +314,7 @@ def next_dagruns_to_examine(
# For dag runs in the queued state, we check if they have reached the max_active_runs limit
# and if so we drop them
running_drs = (
- session.query(DagRun.dag_id, func.count(DagRun.state).label('num_running'))
+ session.query(DagRun.dag_id, func.count(DagRun.state).label("num_running"))
.filter(DagRun.state == DagRunState.RUNNING)
.group_by(DagRun.dag_id)
.subquery()
@@ -317,17 +338,17 @@ def next_dagruns_to_examine(
@provide_session
def find(
cls,
- dag_id: Optional[Union[str, List[str]]] = None,
- run_id: Optional[Iterable[str]] = None,
- execution_date: Optional[Union[datetime, Iterable[datetime]]] = None,
- state: Optional[DagRunState] = None,
- external_trigger: Optional[bool] = None,
+ dag_id: str | list[str] | None = None,
+ run_id: Iterable[str] | None = None,
+ execution_date: datetime | Iterable[datetime] | None = None,
+ state: DagRunState | None = None,
+ external_trigger: bool | None = None,
no_backfills: bool = False,
- run_type: Optional[DagRunType] = None,
+ run_type: DagRunType | None = None,
session: Session = NEW_SESSION,
- execution_start_date: Optional[datetime] = None,
- execution_end_date: Optional[datetime] = None,
- ) -> List["DagRun"]:
+ execution_start_date: datetime | None = None,
+ execution_end_date: datetime | None = None,
+ ) -> list[DagRun]:
"""
Returns a set of dag runs for the given search criteria.
@@ -381,7 +402,7 @@ def find_duplicate(
run_id: str,
execution_date: datetime,
session: Session = NEW_SESSION,
- ) -> Optional['DagRun']:
+ ) -> DagRun | None:
"""
Return an existing run for the DAG with a specific run_id or execution_date.
@@ -404,14 +425,15 @@ def find_duplicate(
@staticmethod
def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
"""Generate Run ID based on Run Type and Execution Date"""
- return f"{run_type}__{execution_date.isoformat()}"
+ # _Ensure_ run_type is a DagRunType, not just a string from user code
+ return DagRunType(run_type).generate_run_id(execution_date)
@provide_session
def get_task_instances(
self,
- state: Optional[Iterable[Optional[TaskInstanceState]]] = None,
+ state: Iterable[TaskInstanceState | None] | None = None,
session: Session = NEW_SESSION,
- ) -> List[TI]:
+ ) -> list[TI]:
"""Returns the task instances for this dag run"""
tis = (
session.query(TI)
@@ -447,7 +469,7 @@ def get_task_instance(
session: Session = NEW_SESSION,
*,
map_index: int = -1,
- ) -> Optional[TI]:
+ ) -> TI | None:
"""
Returns the task instance specified by task_id for this dag run
@@ -460,7 +482,7 @@ def get_task_instance(
.one_or_none()
)
- def get_dag(self) -> "DAG":
+ def get_dag(self) -> DAG:
"""
Returns the Dag associated with this DagRun.
@@ -473,8 +495,8 @@ def get_dag(self) -> "DAG":
@provide_session
def get_previous_dagrun(
- self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION
- ) -> Optional['DagRun']:
+ self, state: DagRunState | None = None, session: Session = NEW_SESSION
+ ) -> DagRun | None:
"""The previous DagRun, if there is one"""
filters = [
DagRun.dag_id == self.dag_id,
@@ -485,7 +507,7 @@ def get_previous_dagrun(
return session.query(DagRun).filter(*filters).order_by(DagRun.execution_date.desc()).first()
@provide_session
- def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> Optional['DagRun']:
+ def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> DagRun | None:
"""The previous, SCHEDULED DagRun, if there is one"""
return (
session.query(DagRun)
@@ -501,19 +523,38 @@ def get_previous_scheduled_dagrun(self, session: Session = NEW_SESSION) -> Optio
@provide_session
def update_state(
self, session: Session = NEW_SESSION, execute_callbacks: bool = True
- ) -> Tuple[List[TI], Optional[DagCallbackRequest]]:
+ ) -> tuple[list[TI], DagCallbackRequest | None]:
"""
Determines the overall state of the DagRun based on the state
of its TaskInstances.
:param session: Sqlalchemy ORM Session
- :param execute_callbacks: Should dag callbacks (success/failure, SLA etc) be invoked
- directly (default: true) or recorded as a pending request in the ``callback`` property
- :return: Tuple containing tis that can be scheduled in the current loop & `callback` that
+ :param execute_callbacks: Should dag callbacks (success/failure, SLA etc.) be invoked
+ directly (default: true) or recorded as a pending request in the ``returned_callback`` property
+ :return: Tuple containing tis that can be scheduled in the current loop & `returned_callback` that
needs to be executed
"""
# Callback to execute in case of Task Failures
- callback: Optional[DagCallbackRequest] = None
+ callback: DagCallbackRequest | None = None
+
+ class _UnfinishedStates(NamedTuple):
+ tis: Sequence[TI]
+
+ @classmethod
+ def calculate(cls, unfinished_tis: Sequence[TI]) -> _UnfinishedStates:
+ return cls(tis=unfinished_tis)
+
+ @property
+ def should_schedule(self) -> bool:
+ return (
+ bool(self.tis)
+ and all(not t.task.depends_on_past for t in self.tis)
+ and all(t.task.max_active_tis_per_dag is None for t in self.tis)
+ and all(t.state != TaskInstanceState.DEFERRED for t in self.tis)
+ )
+
+ def recalculate(self) -> _UnfinishedStates:
+ return self._replace(tis=[t for t in self.tis if t.state in State.unfinished])
start_dttm = timezone.utcnow()
self.last_scheduling_decision = start_dttm
@@ -525,72 +566,82 @@ def update_state(
schedulable_tis = info.schedulable_tis
changed_tis = info.changed_tis
finished_tis = info.finished_tis
- unfinished_tis = info.unfinished_tis
+ unfinished = _UnfinishedStates.calculate(info.unfinished_tis)
- none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tis)
- none_task_concurrency = all(t.task.max_active_tis_per_dag is None for t in unfinished_tis)
- none_deferred = all(t.state != State.DEFERRED for t in unfinished_tis)
-
- if unfinished_tis and none_depends_on_past and none_task_concurrency and none_deferred:
+ if unfinished.should_schedule:
+ are_runnable_tasks = schedulable_tis or changed_tis
# small speed up
- are_runnable_tasks = (
- schedulable_tis
- or self._are_premature_tis(unfinished_tis, finished_tis, session)
- or changed_tis
- )
+ if not are_runnable_tasks:
+ are_runnable_tasks, changed_by_upstream = self._are_premature_tis(
+ unfinished.tis, finished_tis, session
+ )
+ if changed_by_upstream: # Something changed, we need to recalculate!
+ unfinished = unfinished.recalculate()
leaf_task_ids = {t.task_id for t in dag.leaves}
leaf_tis = [ti for ti in tis if ti.task_id in leaf_task_ids if ti.state != TaskInstanceState.REMOVED]
# if all roots finished and at least one failed, the run failed
- if not unfinished_tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
- self.log.error('Marking run %s failed', self)
+ if not unfinished.tis and any(leaf_ti.state in State.failed_states for leaf_ti in leaf_tis):
+ self.log.error("Marking run %s failed", self)
self.set_state(DagRunState.FAILED)
+ self.notify_dagrun_state_changed(msg="task_failure")
+
if execute_callbacks:
- dag.handle_callback(self, success=False, reason='task_failure', session=session)
+ dag.handle_callback(self, success=False, reason="task_failure", session=session)
elif dag.has_on_failure_callback:
+ from airflow.models.dag import DagModel
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id, session)
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=True,
- msg='task_failure',
+ processor_subdir=dag_model.processor_subdir,
+ msg="task_failure",
)
# if all leaves succeeded and no unfinished tasks, the run succeeded
- elif not unfinished_tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
- self.log.info('Marking run %s successful', self)
+ elif not unfinished.tis and all(leaf_ti.state in State.success_states for leaf_ti in leaf_tis):
+ self.log.info("Marking run %s successful", self)
self.set_state(DagRunState.SUCCESS)
+ self.notify_dagrun_state_changed(msg="success")
+
if execute_callbacks:
- dag.handle_callback(self, success=True, reason='success', session=session)
+ dag.handle_callback(self, success=True, reason="success", session=session)
elif dag.has_on_success_callback:
+ from airflow.models.dag import DagModel
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id, session)
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=False,
- msg='success',
+ processor_subdir=dag_model.processor_subdir,
+ msg="success",
)
# if *all tasks* are deadlocked, the run failed
- elif (
- unfinished_tis
- and none_depends_on_past
- and none_task_concurrency
- and none_deferred
- and not are_runnable_tasks
- ):
- self.log.error('Deadlock; marking run %s failed', self)
+ elif unfinished.should_schedule and not are_runnable_tasks:
+ self.log.error("Task deadlock (no runnable tasks); marking run %s failed", self)
self.set_state(DagRunState.FAILED)
+ self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")
+
if execute_callbacks:
- dag.handle_callback(self, success=False, reason='all_tasks_deadlocked', session=session)
+ dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
elif dag.has_on_failure_callback:
+ from airflow.models.dag import DagModel
+
+ dag_model = DagModel.get_dagmodel(dag.dag_id, session)
callback = DagCallbackRequest(
full_filepath=dag.fileloc,
dag_id=self.dag_id,
run_id=self.run_id,
is_failure_callback=True,
- msg='all_tasks_deadlocked',
+ processor_subdir=dag_model.processor_subdir,
+ msg="all_tasks_deadlocked",
)
# finally, if the roots aren't done, the dag is still running
@@ -633,22 +684,23 @@ def update_state(
@provide_session
def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) -> TISchedulingDecision:
+ tis = self.get_task_instances(session=session, state=State.task_states)
+ self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
- schedulable_tis: List[TI] = []
- changed_tis = False
+ def _filter_tis_and_exclude_removed(dag: DAG, tis: list[TI]) -> Iterable[TI]:
+ """Populate ``ti.task`` while excluding those missing one, marking them as REMOVED."""
+ for ti in tis:
+ try:
+ ti.task = dag.get_task(ti.task_id)
+ except TaskNotFound:
+ if ti.state != State.REMOVED:
+ self.log.error("Failed to get task for ti %s. Marking it as removed.", ti)
+ ti.state = State.REMOVED
+ session.flush()
+ else:
+ yield ti
- tis = list(self.get_task_instances(session=session, state=State.task_states))
- self.log.debug("number of tis tasks for %s: %s task(s)", self, len(tis))
- dag = self.get_dag()
- for ti in tis:
- try:
- ti.task = dag.get_task(ti.task_id)
- except TaskNotFound:
- self.log.warning(
- "Failed to get task '%s' for dag '%s'. Marking it as removed.", ti, ti.dag_id
- )
- ti.state = State.REMOVED
- session.flush()
+ tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
unfinished_tis = [t for t in tis if t.state in State.unfinished]
finished_tis = [t for t in tis if t.state in State.finished]
@@ -661,12 +713,16 @@ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) ->
session=session,
)
- # During expansion we may change some tis into non-schedulable
+ # During expansion, we may change some tis into non-schedulable
# states, so we need to re-compute.
if expansion_happened:
+ changed_tis = True
new_unfinished_tis = [t for t in unfinished_tis if t.state in State.unfinished]
finished_tis.extend(t for t in unfinished_tis if t.state in State.finished)
unfinished_tis = new_unfinished_tis
+ else:
+ schedulable_tis = []
+ changed_tis = False
return TISchedulingDecision(
tis=tis,
@@ -676,14 +732,25 @@ def task_instance_scheduling_decisions(self, session: Session = NEW_SESSION) ->
finished_tis=finished_tis,
)
+ def notify_dagrun_state_changed(self, msg: str = ""):
+ if self.state == DagRunState.RUNNING:
+ get_listener_manager().hook.on_dag_run_running(dag_run=self, msg=msg)
+ elif self.state == DagRunState.SUCCESS:
+ get_listener_manager().hook.on_dag_run_success(dag_run=self, msg=msg)
+ elif self.state == DagRunState.FAILED:
+ get_listener_manager().hook.on_dag_run_failed(dag_run=self, msg=msg)
+ # deliberately not notifying on QUEUED
+ # we can't get all the state changes on SchedulerJob, BackfillJob
+ # or LocalTaskJob, so we don't want to "falsely advertise" we notify about that
+
def _get_ready_tis(
self,
- schedulable_tis: List[TI],
- finished_tis: List[TI],
+ schedulable_tis: list[TI],
+ finished_tis: list[TI],
session: Session,
- ) -> Tuple[List[TI], bool, bool]:
+ ) -> tuple[list[TI], bool, bool]:
old_states = {}
- ready_tis: List[TI] = []
+ ready_tis: list[TI] = []
changed_tis = False
if not schedulable_tis:
@@ -691,13 +758,34 @@ def _get_ready_tis(
# If we expand TIs, we need a new list so that we iterate over them too. (We can't alter
# `schedulable_tis` in place and have the `for` loop pick them up
- additional_tis: List[TI] = []
+ additional_tis: list[TI] = []
dep_context = DepContext(
flag_upstream_failed=True,
ignore_unmapped_tasks=True, # Ignore this Dep, as we will expand it if we can.
finished_tis=finished_tis,
)
+ def _expand_mapped_task_if_needed(ti: TI) -> Iterable[TI] | None:
+ """Try to expand the ti, if needed.
+
+ If the ti needs expansion, newly created task instances are
+ returned as well as the original ti.
+ The original ti is also modified in-place and assigned the
+ ``map_index`` of 0.
+
+ If the ti does not need expansion, either because the task is not
+ mapped, or has already been expanded, *None* is returned.
+ """
+ if ti.map_index >= 0: # Already expanded, we're good.
+ return None
+ try:
+ expanded_tis, _ = ti.task.expand_mapped_task(self.run_id, session=session)
+ except NotMapped: # Not a mapped task, nothing needed.
+ return None
+ if expanded_tis:
+ return expanded_tis
+ return ()
+
# Check dependencies.
expansion_happened = False
for schedulable in itertools.chain(schedulable_tis, additional_tis):
@@ -705,18 +793,20 @@ def _get_ready_tis(
if not schedulable.are_dependencies_met(session=session, dep_context=dep_context):
old_states[schedulable.key] = old_state
continue
- # If schedulable is from a mapped task, but not yet expanded, do it
- # now. This is called in two places: First and ideally in the mini
- # scheduler at the end of LocalTaskJob, and then as an "expansion of
- # last resort" in the scheduler to ensure that the mapped task is
- # correctly expanded before executed.
- if schedulable.map_index < 0 and isinstance(schedulable.task, MappedOperator):
- expanded_tis, _ = schedulable.task.expand_mapped_task(self.run_id, session=session)
- if expanded_tis:
- assert expanded_tis[0] is schedulable
- additional_tis.extend(expanded_tis[1:])
- expansion_happened = True
- if schedulable.state in SCHEDULEABLE_STATES:
+ # If schedulable is not yet expanded, try doing it now. This is
+ # called in two places: First and ideally in the mini scheduler at
+ # the end of LocalTaskJob, and then as an "expansion of last resort"
+ # in the scheduler to ensure that the mapped task is correctly
+ # expanded before it is executed. Also see _revise_map_indexes_if_mapped
+ # docstring for additional information.
+ new_tis = None
+ if schedulable.map_index < 0:
+ new_tis = _expand_mapped_task_if_needed(schedulable)
+ if new_tis is not None:
+ additional_tis.extend(new_tis)
+ expansion_happened = True
+ if new_tis is None and schedulable.state in SCHEDULEABLE_STATES:
+ ready_tis.extend(self._revise_map_indexes_if_mapped(schedulable.task, session=session))
ready_tis.append(schedulable)
# Check if any ti changed state
@@ -729,26 +819,24 @@ def _get_ready_tis(
def _are_premature_tis(
self,
- unfinished_tis: List[TI],
- finished_tis: List[TI],
+ unfinished_tis: Sequence[TI],
+ finished_tis: list[TI],
session: Session,
- ) -> bool:
- # there might be runnable tasks that are up for retry and for some reason(retry delay, etc) are
- # not ready yet so we set the flags to count them in
- for ut in unfinished_tis:
- if ut.are_dependencies_met(
- dep_context=DepContext(
- flag_upstream_failed=True,
- ignore_in_retry_period=True,
- ignore_in_reschedule_period=True,
- finished_tis=finished_tis,
- ),
- session=session,
- ):
- return True
- return False
+ ) -> tuple[bool, bool]:
+ dep_context = DepContext(
+ flag_upstream_failed=True,
+ ignore_in_retry_period=True,
+ ignore_in_reschedule_period=True,
+ finished_tis=finished_tis,
+ )
+ # there might be runnable tasks that are up for retry and for some reason (retry delay, etc.) are
+ # not ready yet, so we set the flags to count them in
+ return (
+ any(ut.are_dependencies_met(dep_context=dep_context, session=session) for ut in unfinished_tis),
+ dep_context.have_changed_ti_states,
+ )
- def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: List[TI]) -> None:
+ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: list[TI]) -> None:
"""
This is a helper method to emit the true scheduling delay stats, which is defined as
the time when the first task in DAG starts minus the expected DAG run datetime.
@@ -756,7 +844,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: Lis
is updated to a completed status (either success or failure). The method will find the first
started task within the DAG and calculate the expected DagRun start time (based on
dag.execution_date & dag.timetable), and minus these two values to get the delay.
- The emitted data may contains outlier (e.g. when the first task was cleared, so
+ The emitted data may contain outliers (e.g. when the first task was cleared, so
the second task's start_date will be used), but we can get rid of the outliers
on the stats side through the dashboards tooling built.
Note, the stat will only be emitted if the DagRun is a scheduler triggered one
@@ -778,7 +866,7 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: Lis
ordered_tis_by_start_date = [ti for ti in finished_tis if ti.start_date]
ordered_tis_by_start_date.sort(key=lambda ti: ti.start_date, reverse=False)
- first_start_date = ordered_tis_by_start_date[0].start_date
+ first_start_date = ordered_tis_by_start_date[0].start_date if ordered_tis_by_start_date else None
if first_start_date:
# TODO: Logically, this should be DagRunInfo.run_after, but the
# information is not stored on a DagRun, only before the actual
@@ -788,45 +876,81 @@ def _emit_true_scheduling_delay_stats_for_finished_state(self, finished_tis: Lis
data_interval_end = dag.get_run_data_interval(self).end
true_delay = first_start_date - data_interval_end
if true_delay.total_seconds() > 0:
- Stats.timing(f'dagrun.{dag.dag_id}.first_task_scheduling_delay', true_delay)
+ Stats.timing(f"dagrun.{dag.dag_id}.first_task_scheduling_delay", true_delay)
except Exception:
- self.log.warning('Failed to record first_task_scheduling_delay metric:', exc_info=True)
+ self.log.warning("Failed to record first_task_scheduling_delay metric:", exc_info=True)
def _emit_duration_stats_for_finished_state(self):
if self.state == State.RUNNING:
return
if self.start_date is None:
- self.log.warning('Failed to record duration of %s: start_date is not set.', self)
+ self.log.warning("Failed to record duration of %s: start_date is not set.", self)
return
if self.end_date is None:
- self.log.warning('Failed to record duration of %s: end_date is not set.', self)
+ self.log.warning("Failed to record duration of %s: end_date is not set.", self)
return
duration = self.end_date - self.start_date
if self.state == State.SUCCESS:
- Stats.timing(f'dagrun.duration.success.{self.dag_id}', duration)
+ Stats.timing(f"dagrun.duration.success.{self.dag_id}", duration)
elif self.state == State.FAILED:
- Stats.timing(f'dagrun.duration.failed.{self.dag_id}', duration)
+ Stats.timing(f"dagrun.duration.failed.{self.dag_id}", duration)
@provide_session
- def verify_integrity(self, session: Session = NEW_SESSION):
+ def verify_integrity(self, *, session: Session = NEW_SESSION) -> None:
"""
Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required.
+ :missing_indexes: A dictionary of task vs indexes that are missing.
:param session: Sqlalchemy ORM Session
"""
from airflow.settings import task_instance_mutation_hook
+ # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
+ # Note: Literal[True, False] instead of bool because otherwise it doesn't correctly find the overload.
+ hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, "is_noop", False)
+
dag = self.get_dag()
+ task_ids = self._check_for_removed_or_restored_tasks(
+ dag, task_instance_mutation_hook, session=session
+ )
+
+ def task_filter(task: Operator) -> bool:
+ return task.task_id not in task_ids and (
+ self.is_backfill
+ or task.start_date <= self.execution_date
+ and (task.end_date is None or self.execution_date <= task.end_date)
+ )
+
+ created_counts: dict[str, int] = defaultdict(int)
+ task_creator = self._get_task_creator(created_counts, task_instance_mutation_hook, hook_is_noop)
+
+ # Create the missing tasks, including mapped tasks
+ tasks_to_create = (task for task in dag.task_dict.values() if task_filter(task))
+ tis_to_create = self._create_tasks(tasks_to_create, task_creator, session=session)
+ self._create_task_instances(self.dag_id, tis_to_create, created_counts, hook_is_noop, session=session)
+
+ def _check_for_removed_or_restored_tasks(
+ self, dag: DAG, ti_mutation_hook, *, session: Session
+ ) -> set[str]:
+ """
+ Check for removed, restored, or missing tasks.
+
+ :param dag: DAG object corresponding to the dagrun
+ :param ti_mutation_hook: task_instance_mutation_hook function
+ :param session: Sqlalchemy ORM Session
+
+ :return: Task IDs in the DAG run
+
+ """
tis = self.get_task_instances(session=session)
# check for removed or restored tasks
task_ids = set()
for ti in tis:
- task_instance_mutation_hook(ti)
+ ti_mutation_hook(ti)
task_ids.add(ti.task_id)
- task = None
try:
task = dag.get_task(ti.task_id)
@@ -844,31 +968,15 @@ def verify_integrity(self, session: Session = NEW_SESSION):
ti.state = State.REMOVED
continue
- if not task.is_mapped:
+ try:
+ num_mapped_tis = task.get_parse_time_mapped_ti_count()
+ except NotMapped:
continue
- task = cast("MappedOperator", task)
- num_mapped_tis = task.parse_time_mapped_ti_count
- # Check if the number of mapped literals has changed and we need to mark this TI as removed
- if num_mapped_tis is not None:
- if ti.map_index >= num_mapped_tis:
- self.log.debug(
- "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
- ti,
- num_mapped_tis,
- )
- ti.state = State.REMOVED
- elif ti.map_index < 0:
- self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
- ti.state = State.REMOVED
- else:
- self.log.info("Restoring mapped task '%s'", ti)
- Stats.incr(f"task_restored_to_dag.{dag.dag_id}", 1, 1)
- ti.state = State.NONE
- else:
- # What if it is _now_ dynamically mapped, but wasn't before?
- total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
-
- if total_length is None:
+ except NotFullyPopulated:
+ # What if it is _now_ dynamically mapped, but wasn't before?
+ try:
+ total_length = task.get_mapped_ti_count(self.run_id, session=session)
+ except NotFullyPopulated:
# Not all upstreams finished, so we can't tell what should be here. Remove everything.
if ti.map_index >= 0:
self.log.debug(
@@ -884,23 +992,58 @@ def verify_integrity(self, session: Session = NEW_SESSION):
total_length,
)
ti.state = State.REMOVED
- ...
+ else:
+ # Check if the number of mapped literals has changed, and we need to mark this TI as removed.
+ if ti.map_index >= num_mapped_tis:
+ self.log.debug(
+ "Removing task '%s' as the map_index is longer than the literal mapping list (%s)",
+ ti,
+ num_mapped_tis,
+ )
+ ti.state = State.REMOVED
+ elif ti.map_index < 0:
+ self.log.debug("Removing the unmapped TI '%s' as the mapping can now be performed", ti)
+ ti.state = State.REMOVED
- def task_filter(task: "Operator") -> bool:
- return task.task_id not in task_ids and (
- self.is_backfill
- or task.start_date <= self.execution_date
- and (task.end_date is None or self.execution_date <= task.end_date)
- )
+ return task_ids
- created_counts: Dict[str, int] = defaultdict(int)
+ @overload
+ def _get_task_creator(
+ self,
+ created_counts: dict[str, int],
+ ti_mutation_hook: Callable,
+ hook_is_noop: Literal[True],
+ ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]]]:
+ ...
+
+ @overload
+ def _get_task_creator(
+ self,
+ created_counts: dict[str, int],
+ ti_mutation_hook: Callable,
+ hook_is_noop: Literal[False],
+ ) -> Callable[[Operator, Iterable[int]], Iterator[TI]]:
+ ...
- # Set for the empty default in airflow.settings -- if it's not set this means it has been changed
- hook_is_noop = getattr(task_instance_mutation_hook, 'is_noop', False)
+ def _get_task_creator(
+ self,
+ created_counts: dict[str, int],
+ ti_mutation_hook: Callable,
+ hook_is_noop: Literal[True, False],
+ ) -> Callable[[Operator, Iterable[int]], Iterator[dict[str, Any]] | Iterator[TI]]:
+ """
+ Get the task creator function.
+
+ This function also updates the created_counts dictionary with the number of tasks created.
+
+ :param created_counts: Dictionary of task_type -> count of created TIs
+ :param ti_mutation_hook: task_instance_mutation_hook function
+ :param hook_is_noop: Whether the task_instance_mutation_hook is a noop
+ """
if hook_is_noop:
- def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
+ def create_ti_mapping(task: Operator, indexes: Iterable[int]) -> Iterator[dict[str, Any]]:
created_counts[task.task_type] += 1
for map_index in indexes:
yield TI.insert_mapping(self.run_id, task, map_index=map_index)
@@ -909,30 +1052,67 @@ def create_ti_mapping(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
else:
- def create_ti(task: "Operator", indexes: Tuple[int, ...]) -> Generator:
+ def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
for map_index in indexes:
ti = TI(task, run_id=self.run_id, map_index=map_index)
- task_instance_mutation_hook(ti)
+ ti_mutation_hook(ti)
created_counts[ti.operator] += 1
yield ti
creator = create_ti
+ return creator
- # Create missing tasks -- and expand any MappedOperator that _only_ have literals as input
- def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]:
- if not task.is_mapped:
- return (task, (-1,))
- task = cast("MappedOperator", task)
- count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count(
- self.run_id, session=session
- )
- if not count:
- return (task, (-1,))
- return (task, range(count))
+ def _create_tasks(
+ self,
+ tasks: Iterable[Operator],
+ task_creator: TaskCreator,
+ *,
+ session: Session,
+ ) -> CreatedTasks:
+ """
+ Create missing tasks -- and expand any MappedOperator that _only_ has literals as input
- tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
- tasks = itertools.chain.from_iterable(itertools.starmap(creator, tasks_and_map_idxs))
+ :param tasks: Tasks to create jobs for in the DAG run
+ :param task_creator: Function to create task instances
+ """
+ map_indexes: Iterable[int]
+ for task in tasks:
+ try:
+ count = task.get_mapped_ti_count(self.run_id, session=session)
+ except (NotMapped, NotFullyPopulated):
+ map_indexes = (-1,)
+ else:
+ if count:
+ map_indexes = range(count)
+ else:
+ # Make sure to always create at least one ti; this will be
+ # marked as REMOVED later at runtime.
+ map_indexes = (-1,)
+ yield from task_creator(task, map_indexes)
+
+ def _create_task_instances(
+ self,
+ dag_id: str,
+ tasks: Iterator[dict[str, Any]] | Iterator[TI],
+ created_counts: dict[str, int],
+ hook_is_noop: bool,
+ *,
+ session: Session,
+ ) -> None:
+ """
+ Create the necessary task instances from the given tasks.
+ :param dag_id: DAG ID associated with the dagrun
+ :param tasks: the tasks to create the task instances from
+ :param created_counts: a dictionary of task_type -> count of created TIs, as filled in by the task creator
+ :param hook_is_noop: whether the task_instance_mutation_hook is noop
+ :param session: the session to use
+
+ """
+ # Fetch the information we need before handling the exception to avoid
+ # PendingRollbackError due to the session being invalidated on exception
+ # see https://github.com/apache/superset/pull/530
+ run_id = self.run_id
try:
if hook_is_noop:
session.bulk_insert_mappings(TI, tasks)
@@ -944,17 +1124,62 @@ def expand_mapped_literals(task: "Operator") -> Tuple["Operator", Sequence[int]]
session.flush()
except IntegrityError:
self.log.info(
- 'Hit IntegrityError while creating the TIs for %s- %s',
- dag.dag_id,
- self.run_id,
+ "Hit IntegrityError while creating the TIs for %s- %s",
+ dag_id,
+ run_id,
exc_info=True,
)
- self.log.info('Doing session rollback.')
+ self.log.info("Doing session rollback.")
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()
+ def _revise_map_indexes_if_mapped(self, task: Operator, *, session: Session) -> Iterator[TI]:
+ """Check if task increased or reduced in length and handle appropriately.
+
+ Task instances that do not already exist are created and returned if
+ possible. Expansion only happens if all upstreams are ready; otherwise
+ we delay expansion to the "last resort". See comments at the call site
+ for more details.
+ """
+ from airflow.settings import task_instance_mutation_hook
+
+ try:
+ total_length = task.get_mapped_ti_count(self.run_id, session=session)
+ except NotMapped:
+ return # Not a mapped task, don't need to do anything.
+ except NotFullyPopulated:
+ return # Upstreams not ready, don't need to revise this yet.
+
+ query = session.query(TI.map_index).filter(
+ TI.dag_id == self.dag_id,
+ TI.task_id == task.task_id,
+ TI.run_id == self.run_id,
+ )
+ existing_indexes = {i for (i,) in query}
+
+ removed_indexes = existing_indexes.difference(range(total_length))
+ if removed_indexes:
+ session.query(TI).filter(
+ TI.dag_id == self.dag_id,
+ TI.task_id == task.task_id,
+ TI.run_id == self.run_id,
+ TI.map_index.in_(removed_indexes),
+ ).update({TI.state: TaskInstanceState.REMOVED})
+ session.flush()
+
+ for index in range(total_length):
+ if index in existing_indexes:
+ continue
+ ti = TI(task, run_id=self.run_id, map_index=index, state=None)
+ self.log.debug("Expanding TIs upserted %s", ti)
+ task_instance_mutation_hook(ti)
+ ti = session.merge(ti)
+ ti.refresh_from_task(task)
+ session.flush()
+ yield ti
+
@staticmethod
- def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional['DagRun']:
+ def get_run(session: Session, dag_id: str, execution_date: datetime) -> DagRun | None:
"""
Get a single DAG Run
@@ -964,11 +1189,10 @@ def get_run(session: Session, dag_id: str, execution_date: datetime) -> Optional
:param execution_date: execution date
:return: DagRun corresponding to the given dag_id and execution date
if one exists. None otherwise.
- :rtype: airflow.models.DagRun
"""
warnings.warn(
"This method is deprecated. Please use SQLAlchemy directly",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return (
@@ -987,10 +1211,10 @@ def is_backfill(self) -> bool:
@classmethod
@provide_session
- def get_latest_runs(cls, session=None) -> List['DagRun']:
+ def get_latest_runs(cls, session=None) -> list[DagRun]:
"""Returns the latest DagRun for each DAG"""
subquery = (
- session.query(cls.dag_id, func.max(cls.execution_date).label('execution_date'))
+ session.query(cls.dag_id, func.max(cls.execution_date).label("execution_date"))
.group_by(cls.dag_id)
.subquery()
)
@@ -1010,7 +1234,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES
Each element of ``schedulable_tis`` should have it's ``task`` attribute already set.
- Any EmptyOperator without callbacks is instead set straight to the success state.
+ Any EmptyOperator without callbacks or outlets is instead set straight to the success state.
All the TIs should belong to this DagRun, but this code is in the hot-path, this is not checked -- it
is the caller's responsibility to call this function only with TIs from a single dag run.
@@ -1024,6 +1248,7 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES
ti.task.inherits_from_empty_operator
and not ti.task.on_execute_callback
and not ti.task.on_success_callback
+ and not ti.task.outlets
):
dummy_ti_ids.append(ti.task_id)
else:
@@ -1065,14 +1290,62 @@ def schedule_tis(self, schedulable_tis: Iterable[TI], session: Session = NEW_SES
return count
@provide_session
- def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
+ def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate:
if self.log_template_id is None: # DagRun created before LogTemplate introduction.
- template = session.query(LogTemplate.filename).order_by(LogTemplate.id).limit(1).scalar()
+ template = session.query(LogTemplate).order_by(LogTemplate.id).first()
else:
- template = session.query(LogTemplate.filename).filter_by(id=self.log_template_id).scalar()
+ template = session.query(LogTemplate).get(self.log_template_id)
if template is None:
raise AirflowException(
f"No log_template entry found for ID {self.log_template_id!r}. "
f"Please make sure you set up the metadatabase correctly."
)
return template
+
+ @provide_session
+ def get_log_filename_template(self, *, session: Session = NEW_SESSION) -> str:
+ warnings.warn(
+ "This method is deprecated. Please use get_log_template instead.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
+ return self.get_log_template(session=session).filename
+
+
+class DagRunNote(Base):
+ """For storage of arbitrary notes concerning the dagrun instance."""
+
+ __tablename__ = "dag_run_note"
+
+ user_id = Column(Integer, nullable=True)
+ dag_run_id = Column(Integer, primary_key=True, nullable=False)
+ content = Column(String(1000).with_variant(Text(1000), "mysql"))
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+
+ dag_run = relationship("DagRun", back_populates="dag_run_note")
+
+ __table_args__ = (
+ PrimaryKeyConstraint("dag_run_id", name="dag_run_note_pkey"),
+ ForeignKeyConstraint(
+ (dag_run_id,),
+ ["dag_run.id"],
+ name="dag_run_note_dr_fkey",
+ ondelete="CASCADE",
+ ),
+ ForeignKeyConstraint(
+ (user_id,),
+ ["ab_user.id"],
+ name="dag_run_note_user_fkey",
+ ),
+ )
+
+ def __init__(self, content, user_id=None):
+ self.content = content
+ self.user_id = user_id
+
+ def __repr__(self):
+ # DagRunNote has no dag_id, run_id or map_index columns (unlike a task
+ # instance note), so identify the note by the DagRun it belongs to.
+ prefix = f"<{self.__class__.__name__}: dag_run_id={self.dag_run_id}"
+ return prefix + ">"
diff --git a/airflow/models/dagwarning.py b/airflow/models/dagwarning.py
new file mode 100644
index 0000000000000..db93aafb0fba5
--- /dev/null
+++ b/airflow/models/dagwarning.py
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from enum import Enum
+
+from sqlalchemy import Column, ForeignKeyConstraint, String, Text, false
+from sqlalchemy.orm import Session
+
+from airflow.models.base import Base, StringID
+from airflow.utils import timezone
+from airflow.utils.retries import retry_db_transaction
+from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import UtcDateTime
+
+
+class DagWarning(Base):
+ """
+ A table to store DAG warnings.
+
+ DAG warnings are problems that don't rise to the level of failing the DAG parse
+ but which users should nonetheless be warned about. These warnings are recorded
+ when parsing a DAG and displayed on the webserver in a flash message.
+ """
+
+ dag_id = Column(StringID(), primary_key=True)
+ warning_type = Column(String(50), primary_key=True)
+ message = Column(Text, nullable=False)
+ timestamp = Column(UtcDateTime, nullable=False, default=timezone.utcnow)
+
+ __tablename__ = "dag_warning"
+ __table_args__ = (
+ ForeignKeyConstraint(
+ ("dag_id",),
+ ["dag.dag_id"],
+ name="dcw_dag_id_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+ def __init__(self, dag_id: str, error_type: str, message: str, **kwargs):
+ super().__init__(**kwargs)
+ self.dag_id = dag_id
+ self.warning_type = DagWarningType(error_type).value # make sure valid type
+ self.message = message
+
+ def __eq__(self, other) -> bool:
+ return self.dag_id == other.dag_id and self.warning_type == other.warning_type
+
+ def __hash__(self) -> int:
+ return hash((self.dag_id, self.warning_type))
+
+ @classmethod
+ @provide_session
+ def purge_inactive_dag_warnings(cls, session: Session = NEW_SESSION) -> None:
+ """
+ Delete DagWarning records for inactive DAGs.
+
+ :return: None
+ """
+ cls._purge_inactive_dag_warnings_with_retry(session)
+
+ @classmethod
+ @retry_db_transaction
+ def _purge_inactive_dag_warnings_with_retry(cls, session: Session) -> None:
+ from airflow.models.dag import DagModel
+
+ if session.get_bind().dialect.name == "sqlite":
+ dag_ids = session.query(DagModel.dag_id).filter(DagModel.is_active == false())
+ query = session.query(cls).filter(cls.dag_id.in_(dag_ids))
+ else:
+ query = session.query(cls).filter(cls.dag_id == DagModel.dag_id, DagModel.is_active == false())
+ query.delete(synchronize_session=False)
+ session.commit()
+
+
+class DagWarningType(str, Enum):
+ """
+ Enum for DAG warning types.
+
+ This is the set of allowable values for the ``warning_type`` field
+ in the DagWarning model.
+ """
+
+ NONEXISTENT_POOL = "non-existent pool"
diff --git a/airflow/models/dataset.py b/airflow/models/dataset.py
new file mode 100644
index 0000000000000..4cd370386c4c0
--- /dev/null
+++ b/airflow/models/dataset.py
@@ -0,0 +1,338 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from urllib.parse import urlsplit
+
+import sqlalchemy_jsonfield
+from sqlalchemy import (
+ Boolean,
+ Column,
+ ForeignKey,
+ ForeignKeyConstraint,
+ Index,
+ Integer,
+ PrimaryKeyConstraint,
+ String,
+ Table,
+ text,
+)
+from sqlalchemy.orm import relationship
+
+from airflow.datasets import Dataset
+from airflow.models.base import Base, StringID
+from airflow.settings import json
+from airflow.utils import timezone
+from airflow.utils.sqlalchemy import UtcDateTime
+
+
+class DatasetModel(Base):
+ """
+ A table to store datasets.
+
+ :param uri: a string that uniquely identifies the dataset
+ :param extra: JSON field for arbitrary extra info
+ """
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ uri = Column(
+ String(length=3000).with_variant(
+ String(
+ length=3000,
+ # latin1 allows for more indexed length in mysql
+ # and this field should only be ascii chars
+ collation="latin1_general_cs",
+ ),
+ "mysql",
+ ),
+ nullable=False,
+ )
+ extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+ is_orphaned = Column(Boolean, default=False, nullable=False, server_default="0")
+
+ consuming_dags = relationship("DagScheduleDatasetReference", back_populates="dataset")
+ producing_tasks = relationship("TaskOutletDatasetReference", back_populates="dataset")
+
+ __tablename__ = "dataset"
+ __table_args__ = (
+ Index("idx_uri_unique", uri, unique=True),
+ {"sqlite_autoincrement": True}, # ensures PK values not reused
+ )
+
+ @classmethod
+ def from_public(cls, obj: Dataset) -> DatasetModel:
+ return cls(uri=obj.uri, extra=obj.extra)
+
+ def __init__(self, uri: str, **kwargs):
+ try:
+ uri.encode("ascii")
+ except UnicodeEncodeError:
+ raise ValueError("URI must be ascii")
+ parsed = urlsplit(uri)
+ if parsed.scheme and parsed.scheme.lower() == "airflow":
+ raise ValueError("Scheme `airflow` is reserved.")
+ super().__init__(uri=uri, **kwargs)
+
+ def __eq__(self, other):
+ if isinstance(other, (self.__class__, Dataset)):
+ return self.uri == other.uri
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.uri)
+
+ def __repr__(self):
+ return f"{self.__class__.__name__}(uri={self.uri!r}, extra={self.extra!r})"
+
+
+class DagScheduleDatasetReference(Base):
+ """References from a DAG to a dataset of which it is a consumer."""
+
+ dataset_id = Column(Integer, primary_key=True, nullable=False)
+ dag_id = Column(StringID(), primary_key=True, nullable=False)
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+
+ dataset = relationship("DatasetModel", back_populates="consuming_dags")
+ queue_records = relationship(
+ "DatasetDagRunQueue",
+ primaryjoin="""and_(
+ DagScheduleDatasetReference.dataset_id == foreign(DatasetDagRunQueue.dataset_id),
+ DagScheduleDatasetReference.dag_id == foreign(DatasetDagRunQueue.target_dag_id),
+ )""",
+ cascade="all, delete, delete-orphan",
+ )
+
+ __tablename__ = "dag_schedule_dataset_reference"
+ __table_args__ = (
+ PrimaryKeyConstraint(dataset_id, dag_id, name="dsdr_pkey", mssql_clustered=True),
+ ForeignKeyConstraint(
+ (dataset_id,),
+ ["dataset.id"],
+ name="dsdr_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ ForeignKeyConstraint(
+ columns=(dag_id,),
+ refcolumns=["dag.dag_id"],
+ name="dsdr_dag_id_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.dataset_id == other.dataset_id and self.dag_id == other.dag_id
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.__mapper__.primary_key)
+
+ def __repr__(self):
+ args = []
+ for attr in [x.name for x in self.__mapper__.primary_key]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class TaskOutletDatasetReference(Base):
+ """References from a task to a dataset that it updates / produces."""
+
+ dataset_id = Column(Integer, primary_key=True, nullable=False)
+ dag_id = Column(StringID(), primary_key=True, nullable=False)
+ task_id = Column(StringID(), primary_key=True, nullable=False)
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+
+ dataset = relationship("DatasetModel", back_populates="producing_tasks")
+
+ __tablename__ = "task_outlet_dataset_reference"
+ __table_args__ = (
+ ForeignKeyConstraint(
+ (dataset_id,),
+ ["dataset.id"],
+ name="todr_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ PrimaryKeyConstraint(dataset_id, dag_id, task_id, name="todr_pkey", mssql_clustered=True),
+ ForeignKeyConstraint(
+ columns=(dag_id,),
+ refcolumns=["dag.dag_id"],
+ name="todr_dag_id_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return (
+ self.dataset_id == other.dataset_id
+ and self.dag_id == other.dag_id
+ and self.task_id == other.task_id
+ )
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.__mapper__.primary_key)
+
+ def __repr__(self):
+ args = []
+ for attr in [x.name for x in self.__mapper__.primary_key]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+class DatasetDagRunQueue(Base):
+ """Model for storing dataset events that need processing."""
+
+ dataset_id = Column(Integer, primary_key=True, nullable=False)
+ target_dag_id = Column(StringID(), primary_key=True, nullable=False)
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+ __tablename__ = "dataset_dag_run_queue"
+ __table_args__ = (
+ PrimaryKeyConstraint(dataset_id, target_dag_id, name="datasetdagrunqueue_pkey", mssql_clustered=True),
+ ForeignKeyConstraint(
+ (dataset_id,),
+ ["dataset.id"],
+ name="ddrq_dataset_fkey",
+ ondelete="CASCADE",
+ ),
+ ForeignKeyConstraint(
+ (target_dag_id,),
+ ["dag.dag_id"],
+ name="ddrq_dag_fkey",
+ ondelete="CASCADE",
+ ),
+ )
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.dataset_id == other.dataset_id and self.target_dag_id == other.target_dag_id
+ else:
+ return NotImplemented
+
+ def __hash__(self):
+ return hash(self.__mapper__.primary_key)
+
+ def __repr__(self):
+ args = []
+ for attr in [x.name for x in self.__mapper__.primary_key]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
+
+
+association_table = Table(
+ "dagrun_dataset_event",
+ Base.metadata,
+ Column("dag_run_id", ForeignKey("dag_run.id", ondelete="CASCADE"), primary_key=True),
+ Column("event_id", ForeignKey("dataset_event.id", ondelete="CASCADE"), primary_key=True),
+ Index("idx_dagrun_dataset_events_dag_run_id", "dag_run_id"),
+ Index("idx_dagrun_dataset_events_event_id", "event_id"),
+)
+
+
+class DatasetEvent(Base):
+ """
+ A table to store dataset events.
+
+ :param dataset_id: reference to DatasetModel record
+ :param extra: JSON field for arbitrary extra info
+ :param source_task_id: the task_id of the TI which updated the dataset
+ :param source_dag_id: the dag_id of the TI which updated the dataset
+ :param source_run_id: the run_id of the TI which updated the dataset
+ :param source_map_index: the map_index of the TI which updated the dataset
+ :param timestamp: the time the event was logged
+
+ We use relationships instead of foreign keys so that dataset events are not deleted even
+ if the foreign key object is.
+ """
+
+ id = Column(Integer, primary_key=True, autoincrement=True)
+ dataset_id = Column(Integer, nullable=False)
+ extra = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False, default={})
+ source_task_id = Column(StringID(), nullable=True)
+ source_dag_id = Column(StringID(), nullable=True)
+ source_run_id = Column(StringID(), nullable=True)
+ source_map_index = Column(Integer, nullable=True, server_default=text("-1"))
+ timestamp = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+
+ __tablename__ = "dataset_event"
+ __table_args__ = (
+ Index("idx_dataset_id_timestamp", dataset_id, timestamp),
+ {"sqlite_autoincrement": True}, # ensures PK values not reused
+ )
+
+ created_dagruns = relationship(
+ "DagRun",
+ secondary=association_table,
+ backref="consumed_dataset_events",
+ )
+
+ source_task_instance = relationship(
+ "TaskInstance",
+ primaryjoin="""and_(
+ DatasetEvent.source_dag_id == foreign(TaskInstance.dag_id),
+ DatasetEvent.source_run_id == foreign(TaskInstance.run_id),
+ DatasetEvent.source_task_id == foreign(TaskInstance.task_id),
+ DatasetEvent.source_map_index == foreign(TaskInstance.map_index),
+ )""",
+ viewonly=True,
+ lazy="select",
+ uselist=False,
+ )
+ source_dag_run = relationship(
+ "DagRun",
+ primaryjoin="""and_(
+ DatasetEvent.source_dag_id == foreign(DagRun.dag_id),
+ DatasetEvent.source_run_id == foreign(DagRun.run_id),
+ )""",
+ viewonly=True,
+ lazy="select",
+ uselist=False,
+ )
+ dataset = relationship(
+ DatasetModel,
+ primaryjoin="DatasetEvent.dataset_id == foreign(DatasetModel.id)",
+ viewonly=True,
+ lazy="select",
+ uselist=False,
+ )
+
+ @property
+ def uri(self):
+ return self.dataset.uri
+
+ def __repr__(self) -> str:
+ args = []
+ for attr in [
+ "id",
+ "dataset_id",
+ "extra",
+ "source_task_id",
+ "source_dag_id",
+ "source_run_id",
+ "source_map_index",
+ ]:
+ args.append(f"{attr}={getattr(self, attr)!r}")
+ return f"{self.__class__.__name__}({', '.join(args)})"
diff --git a/airflow/models/db_callback_request.py b/airflow/models/db_callback_request.py
index 4fdd36a71be4b..5a264dee9abc0 100644
--- a/airflow/models/db_callback_request.py
+++ b/airflow/models/db_callback_request.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from importlib import import_module
@@ -36,11 +37,12 @@ class DbCallbackRequest(Base):
priority_weight = Column(Integer(), nullable=False)
callback_data = Column(ExtendedJSON, nullable=False)
callback_type = Column(String(20), nullable=False)
- dag_directory = Column(String(1000), nullable=True)
+ processor_subdir = Column(String(2000), nullable=True)
def __init__(self, priority_weight: int, callback: CallbackRequest):
self.created_at = timezone.utcnow()
self.priority_weight = priority_weight
+ self.processor_subdir = callback.processor_subdir
self.callback_data = callback.to_json()
self.callback_type = callback.__class__.__name__
@@ -48,5 +50,5 @@ def get_callback_request(self) -> CallbackRequest:
module = import_module("airflow.callbacks.callback_requests")
callback_class = getattr(module, self.callback_type)
# Get the function (from the instance) that we need to call
- from_json = getattr(callback_class, 'from_json')
+ from_json = getattr(callback_class, "from_json")
return from_json(self.callback_data)
diff --git a/airflow/models/errors.py b/airflow/models/errors.py
index 9718c063a28f0..974d9b8eebbe5 100644
--- a/airflow/models/errors.py
+++ b/airflow/models/errors.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from sqlalchemy import Column, Integer, String, Text
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
new file mode 100644
index 0000000000000..8a9a3d874012d
--- /dev/null
+++ b/airflow/models/expandinput.py
@@ -0,0 +1,284 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import collections.abc
+import functools
+import operator
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
+
+import attr
+
+from airflow.typing_compat import TypeGuard
+from airflow.utils.context import Context
+from airflow.utils.mixins import ResolveMixin
+from airflow.utils.session import NEW_SESSION, provide_session
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+ from airflow.models.operator import Operator
+ from airflow.models.xcom_arg import XComArg
+
+ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
+
+# Each keyword argument to expand() can be an XComArg, sequence, or dict (not
+# any mapping since we need the value to be ordered).
+OperatorExpandArgument = Union["MappedArgument", "XComArg", Sequence, Dict[str, Any]]
+
+# The single argument of expand_kwargs() can be an XComArg, or a list with each
+# element being either an XComArg or a dict.
+OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]
+
+
+@attr.define(kw_only=True)
+class MappedArgument(ResolveMixin):
+ """Stand-in stub for task-group-mapping arguments.
+
+ This is very similar to an XComArg, but resolved differently. Declared here
+ (instead of in the task group module) to avoid import cycles.
+ """
+
+ _input: ExpandInput
+ _key: str
+
+ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
+ # TODO (AIP-42): Implement run-time task map length inspection. This is
+ # needed when we implement task mapping inside a mapped task group.
+ raise NotImplementedError()
+
+ def iter_references(self) -> Iterable[tuple[Operator, str]]:
+ yield from self._input.iter_references()
+
+ @provide_session
+ def resolve(self, context: Context, *, session: Session = NEW_SESSION) -> Any:
+ data, _ = self._input.resolve(context, session=session)
+ return data[self._key]
+
+
+# To replace tedious isinstance() checks.
+def is_mappable(v: Any) -> TypeGuard[OperatorExpandArgument]:
+ from airflow.models.xcom_arg import XComArg
+
+ return isinstance(v, (MappedArgument, XComArg, Mapping, Sequence)) and not isinstance(v, str)
+
+
+# To replace tedious isinstance() checks.
+def _is_parse_time_mappable(v: OperatorExpandArgument) -> TypeGuard[Mapping | Sequence]:
+ from airflow.models.xcom_arg import XComArg
+
+ return not isinstance(v, (MappedArgument, XComArg))
+
+
+# To replace tedious isinstance() checks.
+def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArgument | XComArg]:
+ from airflow.models.xcom_arg import XComArg
+
+ return isinstance(v, (MappedArgument, XComArg))
+
+
+class NotFullyPopulated(RuntimeError):
+ """Raise when ``get_map_lengths`` cannot populate all mapping metadata.
+
+ This is generally because not all upstream tasks have finished when the
+ function is called.
+ """
+
+ def __init__(self, missing: set[str]) -> None:
+ self.missing = missing
+
+ def __str__(self) -> str:
+ keys = ", ".join(repr(k) for k in sorted(self.missing))
+ return f"Failed to populate all mapping metadata; missing: {keys}"
+
+
+class DictOfListsExpandInput(NamedTuple):
+ """Storage type of a mapped operator's mapped kwargs.
+
+ This is created from ``expand(**kwargs)``.
+ """
+
+ value: dict[str, OperatorExpandArgument]
+
+ def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
+ """Generate kwargs whose values are available at parse time."""
+ return ((k, v) for k, v in self.value.items() if _is_parse_time_mappable(v))
+
+ def get_parse_time_mapped_ti_count(self) -> int:
+ if not self.value:
+ return 0
+ literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
+ if len(literal_values) != len(self.value):
+ literal_keys = (k for k, _ in self._iter_parse_time_resolved_kwargs())
+ raise NotFullyPopulated(set(self.value).difference(literal_keys))
+ return functools.reduce(operator.mul, literal_values, 1)
+
+ def _get_map_lengths(self, run_id: str, *, session: Session) -> dict[str, int]:
+ """Return dict of argument name to map length.
+
+ If the length of any argument is not known yet (upstream task not
+ finished), ``NotFullyPopulated`` is raised listing the missing keys.
+ """
+ # TODO: This initiates one database call for each XComArg. Would it be
+ # more efficient to do one single db call and unpack the value here?
+ def _get_length(v: OperatorExpandArgument) -> int | None:
+ if _needs_run_time_resolution(v):
+ return v.get_task_map_length(run_id, session=session)
+ # Unfortunately a user-defined TypeGuard cannot apply negative type
+ # narrowing. https://github.com/python/typing/discussions/1013
+ if TYPE_CHECKING:
+ assert isinstance(v, Sized)
+ return len(v)
+
+ map_lengths_iterator = ((k, _get_length(v)) for k, v in self.value.items())
+
+ map_lengths = {k: v for k, v in map_lengths_iterator if v is not None}
+ if len(map_lengths) < len(self.value):
+ raise NotFullyPopulated(set(self.value).difference(map_lengths))
+ return map_lengths
+
+ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
+ if not self.value:
+ return 0
+ lengths = self._get_map_lengths(run_id, session=session)
+ return functools.reduce(operator.mul, (lengths[name] for name in self.value), 1)
+
+ def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
+ if _needs_run_time_resolution(value):
+ value = value.resolve(context, session=session)
+ map_index = context["ti"].map_index
+ if map_index < 0:
+ raise RuntimeError("can't resolve task-mapping argument without expanding")
+ all_lengths = self._get_map_lengths(context["run_id"], session=session)
+
+ def _find_index_for_this_field(index: int) -> int:
+ # Need to use the original user input to retain argument order.
+ for mapped_key in reversed(list(self.value)):
+ mapped_length = all_lengths[mapped_key]
+ if mapped_length < 1:
+ raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
+ if mapped_key == key:
+ return index % mapped_length
+ index //= mapped_length
+ return -1
+
+ found_index = _find_index_for_this_field(map_index)
+ if found_index < 0:
+ return value
+ if isinstance(value, collections.abc.Sequence):
+ return value[found_index]
+ if not isinstance(value, dict):
+ raise TypeError(f"can't map over value of type {type(value)}")
+ for i, (k, v) in enumerate(value.items()):
+ if i == found_index:
+ return k, v
+ raise IndexError(f"index {map_index} is over mapped length")
+
+ def iter_references(self) -> Iterable[tuple[Operator, str]]:
+ from airflow.models.xcom_arg import XComArg
+
+ for x in self.value.values():
+ if isinstance(x, XComArg):
+ yield from x.iter_references()
+
+ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
+ literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
+ resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
+ return data, resolved_oids
+
+
+def _describe_type(value: Any) -> str:
+ if value is None:
+ return "None"
+ return type(value).__name__
+
+
+class ListOfDictsExpandInput(NamedTuple):
+ """Storage type of a mapped operator's mapped kwargs.
+
+ This is created from ``expand_kwargs(xcom_arg)``.
+ """
+
+ value: OperatorExpandKwargsArgument
+
+ def get_parse_time_mapped_ti_count(self) -> int:
+ if isinstance(self.value, collections.abc.Sized):
+ return len(self.value)
+ raise NotFullyPopulated({"expand_kwargs() argument"})
+
+ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
+ if isinstance(self.value, collections.abc.Sized):
+ return len(self.value)
+ length = self.value.get_task_map_length(run_id, session=session)
+ if length is None:
+ raise NotFullyPopulated({"expand_kwargs() argument"})
+ return length
+
+ def iter_references(self) -> Iterable[tuple[Operator, str]]:
+ from airflow.models.xcom_arg import XComArg
+
+ if isinstance(self.value, XComArg):
+ yield from self.value.iter_references()
+ else:
+ for x in self.value:
+ if isinstance(x, XComArg):
+ yield from x.iter_references()
+
+ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ map_index = context["ti"].map_index
+ if map_index < 0:
+ raise RuntimeError("can't resolve task-mapping argument without expanding")
+
+ mapping: Any
+ if isinstance(self.value, collections.abc.Sized):
+ mapping = self.value[map_index]
+ if not isinstance(mapping, collections.abc.Mapping):
+ mapping = mapping.resolve(context, session)
+ else:
+ mappings = self.value.resolve(context, session)
+ if not isinstance(mappings, collections.abc.Sequence):
+ raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
+ mapping = mappings[map_index]
+
+ if not isinstance(mapping, collections.abc.Mapping):
+ raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")
+
+ for key in mapping:
+ if not isinstance(key, str):
+ raise ValueError(
+ f"expand_kwargs() input dict keys must all be str, "
+ f"but {key!r} is of type {_describe_type(key)}"
+ )
+ return mapping, {id(v) for v in mapping.values()}
+
+
+EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
+
+_EXPAND_INPUT_TYPES = {
+ "dict-of-lists": DictOfListsExpandInput,
+ "list-of-dicts": ListOfDictsExpandInput,
+}
+
+
+def get_map_type_key(expand_input: ExpandInput) -> str:
+ return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input))
+
+
+def create_expand_input(kind: str, value: Any) -> ExpandInput:
+ return _EXPAND_INPUT_TYPES[kind](value)
diff --git a/airflow/models/log.py b/airflow/models/log.py
index b2a5639dcdd57..7994abc4638ef 100644
--- a/airflow/models/log.py
+++ b/airflow/models/log.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from sqlalchemy import Column, Index, Integer, String, Text, text
+from sqlalchemy import Column, Index, Integer, String, Text
from airflow.models.base import Base, StringID
from airflow.utils import timezone
@@ -32,15 +33,15 @@ class Log(Base):
dttm = Column(UtcDateTime)
dag_id = Column(StringID())
task_id = Column(StringID())
- map_index = Column(Integer, server_default=text('NULL'))
+ map_index = Column(Integer)
event = Column(String(30))
execution_date = Column(UtcDateTime)
owner = Column(String(500))
extra = Column(Text)
__table_args__ = (
- Index('idx_log_dag', dag_id),
- Index('idx_log_event', event),
+ Index("idx_log_dag", dag_id),
+ Index("idx_log_event", event),
)
def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs):
@@ -55,15 +56,19 @@ def __init__(self, event, task_instance=None, owner=None, extra=None, **kwargs):
self.task_id = task_instance.task_id
self.execution_date = task_instance.execution_date
self.map_index = task_instance.map_index
- task_owner = task_instance.task.owner
+ if getattr(task_instance, "task", None):
+ task_owner = task_instance.task.owner
- if 'task_id' in kwargs:
- self.task_id = kwargs['task_id']
- if 'dag_id' in kwargs:
- self.dag_id = kwargs['dag_id']
- if kwargs.get('execution_date'):
- self.execution_date = kwargs['execution_date']
- if 'map_index' in kwargs:
- self.map_index = kwargs['map_index']
+ if "task_id" in kwargs:
+ self.task_id = kwargs["task_id"]
+ if "dag_id" in kwargs:
+ self.dag_id = kwargs["dag_id"]
+ if kwargs.get("execution_date"):
+ self.execution_date = kwargs["execution_date"]
+ if "map_index" in kwargs:
+ self.map_index = kwargs["map_index"]
self.owner = owner or task_owner
+
+ def __str__(self) -> str:
+ return f"Log({self.event}, {self.task_id}, {self.owner}, {self.extra})"
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index fe18a97cc1414..99e2b67f50018 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -15,38 +15,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import collections
import collections.abc
+import contextlib
+import copy
import datetime
-import functools
-import operator
import warnings
-from typing import (
- TYPE_CHECKING,
- Any,
- ClassVar,
- Collection,
- Dict,
- FrozenSet,
- Iterable,
- Iterator,
- List,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- Union,
-)
+from typing import TYPE_CHECKING, Any, ClassVar, Collection, Iterable, Iterator, Mapping, Sequence, Union
import attr
import pendulum
-from sqlalchemy import func, or_
from sqlalchemy.orm.session import Session
from airflow import settings
-from airflow.compat.functools import cache, cached_property
+from airflow.compat.functools import cache
from airflow.exceptions import AirflowException, UnmappableOperator
from airflow.models.abstractoperator import (
DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST,
@@ -59,17 +43,26 @@
DEFAULT_TRIGGER_RULE,
DEFAULT_WEIGHT_RULE,
AbstractOperator,
+ NotMapped,
TaskStateChangeCallback,
)
+from airflow.models.expandinput import (
+ DictOfListsExpandInput,
+ ExpandInput,
+ ListOfDictsExpandInput,
+ OperatorExpandArgument,
+ OperatorExpandKwargsArgument,
+ is_mappable,
+)
+from airflow.models.param import ParamsDict
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
-from airflow.utils.context import Context
-from airflow.utils.helpers import is_container
+from airflow.utils.context import Context, context_update_for_unmapped
+from airflow.utils.helpers import is_container, prevent_duplicates
from airflow.utils.operator_resources import Resources
-from airflow.utils.state import State, TaskInstanceState
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import NOTSET
@@ -79,29 +72,13 @@
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG
from airflow.models.operator import Operator
- from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup
- # BaseOperator.expand() can be called on an XComArg, sequence, or dict (not
- # any mapping since we need the value to be ordered).
- Mappable = Union[XComArg, Sequence, dict]
-
ValidationSource = Union[Literal["expand"], Literal["partial"]]
-MAPPABLE_LITERAL_TYPES = (dict, list)
-
-
-# For isinstance() check.
-@cache
-def get_mappable_types() -> Tuple[type, ...]:
- from airflow.models.xcom_arg import XComArg
-
- return (XComArg,) + MAPPABLE_LITERAL_TYPES
-
-
-def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None:
+def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
# use a dict so order of args is same as code order
unknown_args = value.copy()
for klass in op.mro():
@@ -116,7 +93,7 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va
continue
if value is NOTSET:
continue
- if isinstance(value, get_mappable_types()):
+ if is_mappable(value):
continue
type_name = type(value).__name__
error = f"{op.__name__}.expand() got an unexpected type {type_name!r} for keyword argument {name}"
@@ -132,22 +109,13 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va
raise TypeError(f"{op.__name__}.{func}() got {error}")
-def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> None:
- duplicated_keys = set(kwargs1).intersection(kwargs2)
- if not duplicated_keys:
- return
- if len(duplicated_keys) == 1:
- raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
- duplicated_keys_display = ", ".join(sorted(duplicated_keys))
- raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
-
-
def ensure_xcomarg_return_value(arg: Any) -> None:
from airflow.models.xcom_arg import XCOM_RETURN_KEY, XComArg
if isinstance(arg, XComArg):
- if arg.key != XCOM_RETURN_KEY:
- raise ValueError(f"cannot map over XCom with custom key {arg.key!r} from {arg.operator}")
+ for operator, key in arg.iter_references():
+ if key != XCOM_RETURN_KEY:
+ raise ValueError(f"cannot map over XCom with custom key {key!r} from {operator}")
elif not is_container(arg):
return
elif isinstance(arg, collections.abc.Mapping):
@@ -167,8 +135,9 @@ class OperatorPartial:
create a ``MappedOperator`` to add into the DAG.
"""
- operator_class: Type["BaseOperator"]
- kwargs: Dict[str, Any]
+ operator_class: type[BaseOperator]
+ kwargs: dict[str, Any]
+ params: ParamsDict | dict
_expand_called: bool = False # Set when expand() is called to ease user debugging.
@@ -191,34 +160,50 @@ def __del__(self):
task_id = f"at {hex(id(self))}"
warnings.warn(f"Task {task_id} was never mapped!")
- def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
+ def expand(self, **mapped_kwargs: OperatorExpandArgument) -> MappedOperator:
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
- return self._expand(**mapped_kwargs)
+ validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
+ prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
+ # Since the input is already checked at parse time, we can set strict
+ # to False to skip the checks on execution.
+ return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
- def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
- self._expand_called = True
+ def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> MappedOperator:
+ from airflow.models.xcom_arg import XComArg
+ if isinstance(kwargs, collections.abc.Sequence):
+ for item in kwargs:
+ if not isinstance(item, (XComArg, collections.abc.Mapping)):
+ raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+ elif not isinstance(kwargs, XComArg):
+ raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
+ return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
+ def _expand(self, expand_input: ExpandInput, *, strict: bool) -> MappedOperator:
from airflow.operators.empty import EmptyOperator
- validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
- prevent_duplicates(self.kwargs, mapped_kwargs, fail_reason="unmappable or already specified")
- ensure_xcomarg_return_value(mapped_kwargs)
+ self._expand_called = True
+ ensure_xcomarg_return_value(expand_input.value)
partial_kwargs = self.kwargs.copy()
task_id = partial_kwargs.pop("task_id")
- params = partial_kwargs.pop("params")
dag = partial_kwargs.pop("dag")
task_group = partial_kwargs.pop("task_group")
start_date = partial_kwargs.pop("start_date")
end_date = partial_kwargs.pop("end_date")
+ try:
+ operator_name = self.operator_class.custom_operator_name # type: ignore
+ except AttributeError:
+ operator_name = self.operator_class.__name__
+
op = MappedOperator(
operator_class=self.operator_class,
- mapped_kwargs=mapped_kwargs,
+ expand_input=expand_input,
partial_kwargs=partial_kwargs,
task_id=task_id,
- params=params,
+ params=self.params,
deps=MappedOperator.deps_for(self.operator_class),
operator_extra_links=self.operator_class.operator_extra_links,
template_ext=self.operator_class.template_ext,
@@ -229,18 +214,30 @@ def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
is_empty=issubclass(self.operator_class, EmptyOperator),
task_module=self.operator_class.__module__,
task_type=self.operator_class.__name__,
+ operator_name=operator_name,
dag=dag,
task_group=task_group,
start_date=start_date,
end_date=end_date,
- # For classic operators, this points to mapped_kwargs because kwargs
+ disallow_kwargs_override=strict,
+ # For classic operators, this points to expand_input because kwargs
# to BaseOperator.expand() contribute to operator arguments.
- expansion_kwargs_attr="mapped_kwargs",
+ expand_input_attr="expand_input",
)
return op
-@attr.define(kw_only=True)
+@attr.define(
+ kw_only=True,
+ # Disable custom __getstate__ and __setstate__ generation since it interacts
+ # badly with Airflow's DAG serialization and pickling. When a mapped task is
+ # deserialized, subclasses are coerced into MappedOperator, but when it goes
+ # through DAG pickling, all attributes defined in the subclasses are dropped
+ # by attrs's custom state management. Since attrs does not do anything too
+ # special here (the logic is only important for slots=True), we use Python's
+ # built-in implementation, which works (as proven by good old BaseOperator).
+ getstate_setstate=False,
+)
class MappedOperator(AbstractOperator):
"""Object representing a mapped operator in a DAG."""
@@ -249,45 +246,52 @@ class MappedOperator(AbstractOperator):
# can be used to create an unmapped operator for execution. For an operator
# recreated from a serialized DAG, however, this holds the serialized data
# that can be used to unmap this into a SerializedBaseOperator.
- operator_class: Union[Type["BaseOperator"], Dict[str, Any]]
+ operator_class: type[BaseOperator] | dict[str, Any]
- mapped_kwargs: Dict[str, "Mappable"]
- partial_kwargs: Dict[str, Any]
+ expand_input: ExpandInput
+ partial_kwargs: dict[str, Any]
# Needed for serialization.
task_id: str
- params: Optional[dict]
- deps: FrozenSet[BaseTIDep]
- operator_extra_links: Collection["BaseOperatorLink"]
+ params: ParamsDict | dict
+ deps: frozenset[BaseTIDep]
+ operator_extra_links: Collection[BaseOperatorLink]
template_ext: Sequence[str]
template_fields: Collection[str]
- template_fields_renderers: Dict[str, str]
+ template_fields_renderers: dict[str, str]
ui_color: str
ui_fgcolor: str
_is_empty: bool
_task_module: str
_task_type: str
+ _operator_name: str
+
+ dag: DAG | None
+ task_group: TaskGroup | None
+ start_date: pendulum.DateTime | None
+ end_date: pendulum.DateTime | None
+ upstream_task_ids: set[str] = attr.ib(factory=set, init=False)
+ downstream_task_ids: set[str] = attr.ib(factory=set, init=False)
- dag: Optional["DAG"]
- task_group: Optional["TaskGroup"]
- start_date: Optional[pendulum.DateTime]
- end_date: Optional[pendulum.DateTime]
- upstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
- downstream_task_ids: Set[str] = attr.ib(factory=set, init=False)
+ _disallow_kwargs_override: bool
+ """Whether execution fails if ``expand_input`` shares keys with ``partial_kwargs``.
- _expansion_kwargs_attr: str
+ If *False*, values from ``expand_input`` under duplicate keys override those
+ under corresponding keys in ``partial_kwargs``.
+ """
+
+ _expand_input_attr: str
"""Where to get kwargs to calculate expansion length against.
This should be a name to call ``getattr()`` on.
"""
- is_mapped: ClassVar[bool] = True
subdag: None = None # Since we don't support SubDagOperator, this is always None.
- HIDE_ATTRS_FROM_UI: ClassVar[FrozenSet[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
+ HIDE_ATTRS_FROM_UI: ClassVar[frozenset[str]] = AbstractOperator.HIDE_ATTRS_FROM_UI | frozenset(
(
- 'parse_time_mapped_ti_count',
- 'operator_class',
+ "parse_time_mapped_ti_count",
+ "operator_class",
)
)
@@ -300,17 +304,18 @@ def __repr__(self):
def __attrs_post_init__(self):
from airflow.models.xcom_arg import XComArg
- self._validate_argument_count()
+ if self.get_closest_mapped_task_group() is not None:
+ raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
+
if self.task_group:
self.task_group.add(self)
if self.dag:
self.dag.add_task(self)
- for k, v in self.mapped_kwargs.items():
- XComArg.apply_upstream_relationship(self, v)
+ XComArg.apply_upstream_relationship(self, self.expand_input.value)
for k, v in self.partial_kwargs.items():
if k in self.template_fields:
XComArg.apply_upstream_relationship(self, v)
- if self.partial_kwargs.get('sla') is not None:
+ if self.partial_kwargs.get("sla") is not None:
raise AirflowException(
f"SLAs are unsupported with mapped tasks. Please set `sla=None` for task "
f"{self.task_id!r}."
@@ -323,7 +328,7 @@ def get_serialized_fields(cls):
return frozenset(attr.fields_dict(MappedOperator)) - {
"dag",
"deps",
- "is_mapped",
+ "expand_input", # This is needed to be able to accept XComArg.
"subdag",
"task_group",
"upstream_task_ids",
@@ -331,7 +336,7 @@ def get_serialized_fields(cls):
@staticmethod
@cache
- def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]:
+ def deps_for(operator_class: type[BaseOperator]) -> frozenset[BaseTIDep]:
operator_deps = operator_class.deps
if not isinstance(operator_deps, collections.abc.Set):
raise UnmappableOperator(
@@ -340,23 +345,15 @@ def deps_for(operator_class: Type["BaseOperator"]) -> FrozenSet[BaseTIDep]:
)
return operator_deps | {MappedTaskIsExpanded()}
- def _validate_argument_count(self) -> None:
- """Validate mapping arguments by unmapping with mocked values.
-
- This ensures the user passed enough arguments in the DAG definition for
- the operator to work in the task runner. This does not guarantee the
- arguments are *valid* (that depends on the actual mapping values), but
- makes sure there are *enough* of them.
- """
- if not isinstance(self.operator_class, type):
- return # No need to validate deserialized operator.
- self.operator_class.validate_mapped_arguments(**self._get_unmap_kwargs())
-
@property
def task_type(self) -> str:
"""Implementing Operator."""
return self._task_type
+ @property
+ def operator_name(self) -> str:
+ return self._operator_name
+
@property
def inherits_from_empty_operator(self) -> bool:
"""Implementing Operator."""
@@ -377,7 +374,7 @@ def owner(self) -> str: # type: ignore[override]
return self.partial_kwargs.get("owner", DEFAULT_OWNER)
@property
- def email(self) -> Union[None, str, Iterable[str]]:
+ def email(self) -> None | str | Iterable[str]:
return self.partial_kwargs.get("email")
@property
@@ -398,7 +395,7 @@ def wait_for_downstream(self) -> bool:
return bool(self.partial_kwargs.get("wait_for_downstream"))
@property
- def retries(self) -> Optional[int]:
+ def retries(self) -> int | None:
return self.partial_kwargs.get("retries", DEFAULT_RETRIES)
@property
@@ -410,15 +407,15 @@ def pool(self) -> str:
return self.partial_kwargs.get("pool", Pool.DEFAULT_POOL_NAME)
@property
- def pool_slots(self) -> Optional[str]:
+ def pool_slots(self) -> str | None:
return self.partial_kwargs.get("pool_slots", DEFAULT_POOL_SLOTS)
@property
- def execution_timeout(self) -> Optional[datetime.timedelta]:
+ def execution_timeout(self) -> datetime.timedelta | None:
return self.partial_kwargs.get("execution_timeout")
@property
- def max_retry_delay(self) -> Optional[datetime.timedelta]:
+ def max_retry_delay(self) -> datetime.timedelta | None:
return self.partial_kwargs.get("max_retry_delay")
@property
@@ -438,109 +435,175 @@ def weight_rule(self) -> int: # type: ignore[override]
return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE)
@property
- def sla(self) -> Optional[datetime.timedelta]:
+ def sla(self) -> datetime.timedelta | None:
return self.partial_kwargs.get("sla")
@property
- def max_active_tis_per_dag(self) -> Optional[int]:
+ def max_active_tis_per_dag(self) -> int | None:
return self.partial_kwargs.get("max_active_tis_per_dag")
@property
- def resources(self) -> Optional[Resources]:
+ def resources(self) -> Resources | None:
return self.partial_kwargs.get("resources")
@property
- def on_execute_callback(self) -> Optional[TaskStateChangeCallback]:
+ def on_execute_callback(self) -> TaskStateChangeCallback | None:
return self.partial_kwargs.get("on_execute_callback")
+ @on_execute_callback.setter
+ def on_execute_callback(self, value: TaskStateChangeCallback | None) -> None:
+ self.partial_kwargs["on_execute_callback"] = value
+
@property
- def on_failure_callback(self) -> Optional[TaskStateChangeCallback]:
+ def on_failure_callback(self) -> TaskStateChangeCallback | None:
return self.partial_kwargs.get("on_failure_callback")
+ @on_failure_callback.setter
+ def on_failure_callback(self, value: TaskStateChangeCallback | None) -> None:
+ self.partial_kwargs["on_failure_callback"] = value
+
@property
- def on_retry_callback(self) -> Optional[TaskStateChangeCallback]:
+ def on_retry_callback(self) -> TaskStateChangeCallback | None:
return self.partial_kwargs.get("on_retry_callback")
+ @on_retry_callback.setter
+ def on_retry_callback(self, value: TaskStateChangeCallback | None) -> None:
+ self.partial_kwargs["on_retry_callback"] = value
+
@property
- def on_success_callback(self) -> Optional[TaskStateChangeCallback]:
+ def on_success_callback(self) -> TaskStateChangeCallback | None:
return self.partial_kwargs.get("on_success_callback")
+ @on_success_callback.setter
+ def on_success_callback(self, value: TaskStateChangeCallback | None) -> None:
+ self.partial_kwargs["on_success_callback"] = value
+
@property
- def run_as_user(self) -> Optional[str]:
+ def run_as_user(self) -> str | None:
return self.partial_kwargs.get("run_as_user")
@property
def executor_config(self) -> dict:
return self.partial_kwargs.get("executor_config", {})
- @property
- def inlets(self) -> Optional[Any]:
- return self.partial_kwargs.get("inlets", None)
+ @property # type: ignore[override]
+ def inlets(self) -> list[Any]: # type: ignore[override]
+ return self.partial_kwargs.get("inlets", [])
- @property
- def outlets(self) -> Optional[Any]:
- return self.partial_kwargs.get("outlets", None)
+ @inlets.setter
+ def inlets(self, value: list[Any]) -> None: # type: ignore[override]
+ self.partial_kwargs["inlets"] = value
+
+ @property # type: ignore[override]
+ def outlets(self) -> list[Any]: # type: ignore[override]
+ return self.partial_kwargs.get("outlets", [])
+
+ @outlets.setter
+ def outlets(self, value: list[Any]) -> None: # type: ignore[override]
+ self.partial_kwargs["outlets"] = value
@property
- def doc(self) -> Optional[str]:
+ def doc(self) -> str | None:
return self.partial_kwargs.get("doc")
@property
- def doc_md(self) -> Optional[str]:
+ def doc_md(self) -> str | None:
return self.partial_kwargs.get("doc_md")
@property
- def doc_json(self) -> Optional[str]:
+ def doc_json(self) -> str | None:
return self.partial_kwargs.get("doc_json")
@property
- def doc_yaml(self) -> Optional[str]:
+ def doc_yaml(self) -> str | None:
return self.partial_kwargs.get("doc_yaml")
@property
- def doc_rst(self) -> Optional[str]:
+ def doc_rst(self) -> str | None:
return self.partial_kwargs.get("doc_rst")
- def get_dag(self) -> Optional["DAG"]:
+ def get_dag(self) -> DAG | None:
"""Implementing Operator."""
return self.dag
- def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
+ @property
+ def output(self) -> XComArg:
+ """Return a reference to the XCom pushed by the current operator."""
+ from airflow.models.xcom_arg import XComArg
+
+ return XComArg(operator=self)
+
+ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
"""Implementing DAGNode."""
return DagAttributeTypes.OP, self.task_id
- def _get_unmap_kwargs(self) -> Dict[str, Any]:
+ def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
+ """Get the kwargs to create the unmapped operator.
+
+ This exists because taskflow operators expand against op_kwargs, not the
+ entire operator kwargs dict.
+ """
+ return self._get_specified_expand_input().resolve(context, session)
+
+ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]:
+ """Get init kwargs to unmap the underlying operator class.
+
+ :param mapped_kwargs: The mapping returned (as the first element) by ``_expand_mapped_kwargs``.
+ """
+ if strict:
+ prevent_duplicates(
+ self.partial_kwargs,
+ mapped_kwargs,
+ fail_reason="unmappable or already specified",
+ )
+
+ # If params appears in the mapped kwargs, we need to merge it into the
+ # partial params, overriding existing keys.
+ params = copy.copy(self.params)
+ with contextlib.suppress(KeyError):
+ params.update(mapped_kwargs["params"])
+
+ # Ordering is significant; mapped kwargs should override partial ones,
+ # and the specially handled params should be respected.
return {
"task_id": self.task_id,
"dag": self.dag,
"task_group": self.task_group,
- "params": self.params,
"start_date": self.start_date,
"end_date": self.end_date,
**self.partial_kwargs,
- **self.mapped_kwargs,
+ **mapped_kwargs,
+ "params": params,
}
- def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator":
- """
- Get the "normal" Operator after applying the current mapping.
+ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> BaseOperator:
+ """Get the "normal" Operator after applying the current mapping.
- If ``operator_class`` is not a class (i.e. this DAG has been deserialized) then this will return a
- SerializedBaseOperator that aims to "look like" the real operator.
+ The *resolve* argument is only used if ``operator_class`` is a real
+ class, i.e. if this operator is not serialized. If ``operator_class`` is
+ not a class (i.e. this DAG has been deserialized), this returns a
+ SerializedBaseOperator that "looks like" the actual unmapping result.
- :param unmap_kwargs: Override the args to pass to the Operator constructor. Only used when
- ``operator_class`` is still an actual class.
+ If *resolve* is a two-tuple (context, session), the information is used
+ to resolve the mapped arguments into init arguments. If it is a mapping,
+ no resolving happens; the mapping directly provides the init arguments
+ already resolved from mapped kwargs.
:meta private:
"""
if isinstance(self.operator_class, type):
- # We can't simply specify task_id here because BaseOperator further
+ if isinstance(resolve, collections.abc.Mapping):
+ kwargs = resolve
+ elif resolve is not None:
+ kwargs, _ = self._expand_mapped_kwargs(*resolve)
+ else:
+ raise RuntimeError("cannot unmap a non-serialized operator without context")
+ kwargs = self._get_unmap_kwargs(kwargs, strict=self._disallow_kwargs_override)
+ op = self.operator_class(**kwargs, _airflow_from_mapped=True)
+ # We need to overwrite task_id here because BaseOperator further
# mangles the task_id based on the task hierarchy (namely, group_id
- # is prepended, and '__N' appended to deduplicate). Instead of
- # recreating the whole logic here, we just overwrite task_id later.
- if unmap_kwargs is None:
- unmap_kwargs = self._get_unmap_kwargs()
- op = self.operator_class(**unmap_kwargs, _airflow_from_mapped=True)
+ # is prepended, and '__N' appended to deduplicate). This is hacky,
+ # but better than duplicating the whole mangling logic.
op.task_id = self.task_id
return op
@@ -550,311 +613,77 @@ def unmap(self, unmap_kwargs: Optional[Dict[str, Any]] = None) -> "BaseOperator"
# mapped operator to a new SerializedBaseOperator instance.
from airflow.serialization.serialized_objects import SerializedBaseOperator
- op = SerializedBaseOperator(task_id=self.task_id, _airflow_from_mapped=True)
+ op = SerializedBaseOperator(task_id=self.task_id, params=self.params, _airflow_from_mapped=True)
SerializedBaseOperator.populate_operator(op, self.operator_class)
return op
- def _get_expansion_kwargs(self) -> Dict[str, "Mappable"]:
- """The kwargs to calculate expansion length against."""
- return getattr(self, self._expansion_kwargs_attr)
+ def _get_specified_expand_input(self) -> ExpandInput:
+ """Input received from the expand call on the operator."""
+ return getattr(self, self._expand_input_attr)
- def _get_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
- """Return dict of argument name to map length.
+ def prepare_for_execution(self) -> MappedOperator:
+ # Since a mapped operator cannot be used for execution, and an unmapped
+ # BaseOperator needs to be created later (see render_template_fields),
+ # we don't need to create a copy of the MappedOperator here.
+ return self
- If any arguments are not known right now (upstream task not finished) they will not be present in the
- dict.
- """
- # TODO: Find a way to cache this.
- from airflow.models.taskmap import TaskMap
- from airflow.models.xcom import XCom
+ def iter_mapped_dependencies(self) -> Iterator[Operator]:
+ """Upstream dependencies that provide XComs used by this task for task mapping."""
from airflow.models.xcom_arg import XComArg
- expansion_kwargs = self._get_expansion_kwargs()
-
- # Populate literal mapped arguments first.
- map_lengths: Dict[str, int] = collections.defaultdict(int)
- map_lengths.update((k, len(v)) for k, v in expansion_kwargs.items() if not isinstance(v, XComArg))
-
- # Build a reverse mapping of what arguments each task contributes to.
- mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
- non_mapped_dep_keys: Dict[str, Set[str]] = collections.defaultdict(set)
- for k, v in expansion_kwargs.items():
- if not isinstance(v, XComArg):
- continue
- if v.operator.is_mapped:
- mapped_dep_keys[v.operator.task_id].add(k)
- else:
- non_mapped_dep_keys[v.operator.task_id].add(k)
- # TODO: It's not possible now, but in the future (AIP-42 Phase 2)
- # we will add support for depending on one single mapped task
- # instance. When that happens, we need to further analyze the mapped
- # case to contain only tasks we depend on "as a whole", and put
- # those we only depend on individually to the non-mapped lookup.
-
- # Collect lengths from unmapped upstreams.
- taskmap_query = session.query(TaskMap.task_id, TaskMap.length).filter(
- TaskMap.dag_id == self.dag_id,
- TaskMap.run_id == run_id,
- TaskMap.task_id.in_(non_mapped_dep_keys),
- TaskMap.map_index < 0,
- )
- for task_id, length in taskmap_query:
- for mapped_arg_name in non_mapped_dep_keys[task_id]:
- map_lengths[mapped_arg_name] += length
-
- # Collect lengths from mapped upstreams.
- xcom_query = (
- session.query(XCom.task_id, func.count(XCom.map_index))
- .group_by(XCom.task_id)
- .filter(
- XCom.dag_id == self.dag_id,
- XCom.run_id == run_id,
- XCom.task_id.in_(mapped_dep_keys),
- XCom.map_index >= 0,
- )
- )
- for task_id, length in xcom_query:
- for mapped_arg_name in mapped_dep_keys[task_id]:
- map_lengths[mapped_arg_name] += length
- return map_lengths
+ for operator, _ in XComArg.iter_xcom_references(self._get_specified_expand_input()):
+ yield operator
@cache
- def _resolve_map_lengths(self, run_id: str, *, session: Session) -> Dict[str, int]:
- """Return dict of argument name to map length, or throw if some are not resolvable"""
- expansion_kwargs = self._get_expansion_kwargs()
- map_lengths = self._get_map_lengths(run_id, session=session)
- if len(map_lengths) < len(expansion_kwargs):
- keys = ", ".join(repr(k) for k in sorted(set(expansion_kwargs).difference(map_lengths)))
- raise RuntimeError(f"Failed to populate all mapping metadata; missing: {keys}")
-
- return map_lengths
-
- def expand_mapped_task(self, run_id: str, *, session: Session) -> Tuple[Sequence["TaskInstance"], int]:
- """Create the mapped task instances for mapped task.
-
- :return: The newly created mapped TaskInstances (if any) in ascending order by map index, and the
- maximum map_index.
- """
- from airflow.models.taskinstance import TaskInstance
- from airflow.settings import task_instance_mutation_hook
-
- total_length = functools.reduce(
- operator.mul, self._resolve_map_lengths(run_id, session=session).values()
- )
-
- state: Optional[TaskInstanceState] = None
- unmapped_ti: Optional[TaskInstance] = (
- session.query(TaskInstance)
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == run_id,
- TaskInstance.map_index == -1,
- or_(TaskInstance.state.in_(State.unfinished), TaskInstance.state.is_(None)),
- )
- .one_or_none()
- )
+ def get_parse_time_mapped_ti_count(self) -> int:
+ current_count = self._get_specified_expand_input().get_parse_time_mapped_ti_count()
+ try:
+ parent_count = super().get_parse_time_mapped_ti_count()
+ except NotMapped:
+ return current_count
+ return parent_count * current_count
- all_expanded_tis: List[TaskInstance] = []
-
- if unmapped_ti:
- # The unmapped task instance still exists and is unfinished, i.e. we
- # haven't tried to run it before.
- if total_length < 1:
- # If the upstream maps this to a zero-length value, simply marked the
- # unmapped task instance as SKIPPED (if needed).
- self.log.info(
- "Marking %s as SKIPPED since the map has %d values to expand",
- unmapped_ti,
- total_length,
- )
- unmapped_ti.state = TaskInstanceState.SKIPPED
- else:
- # Otherwise convert this into the first mapped index, and create
- # TaskInstance for other indexes.
- unmapped_ti.map_index = 0
- self.log.debug("Updated in place to become %s", unmapped_ti)
- all_expanded_tis.append(unmapped_ti)
- state = unmapped_ti.state
- indexes_to_map = range(1, total_length)
- else:
- # Only create "missing" ones.
- current_max_mapping = (
- session.query(func.max(TaskInstance.map_index))
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == run_id,
- )
- .scalar()
- )
- indexes_to_map = range(current_max_mapping + 1, total_length)
-
- for index in indexes_to_map:
- # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
- ti = TaskInstance(self, run_id=run_id, map_index=index, state=state)
- self.log.debug("Expanding TIs upserted %s", ti)
- task_instance_mutation_hook(ti)
- ti = session.merge(ti)
- ti.refresh_from_task(self) # session.merge() loses task information.
- all_expanded_tis.append(ti)
-
- # Set to "REMOVED" any (old) TaskInstances with map indices greater
- # than the current map value
- session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == run_id,
- TaskInstance.map_index >= total_length,
- ).update({TaskInstance.state: TaskInstanceState.REMOVED})
-
- session.flush()
- return all_expanded_tis, total_length
-
- def prepare_for_execution(self) -> "MappedOperator":
- # Since a mapped operator cannot be used for execution, and an unmapped
- # BaseOperator needs to be created later (see render_template_fields),
- # we don't need to create a copy of the MappedOperator here.
- return self
+ def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
+ current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
+ try:
+ parent_count = super().get_mapped_ti_count(run_id, session=session)
+ except NotMapped:
+ return current_count
+ return parent_count * current_count
def render_template_fields(
self,
context: Context,
- jinja_env: Optional["jinja2.Environment"] = None,
- ) -> Optional["BaseOperator"]:
- """Template all attributes listed in template_fields.
+ jinja_env: jinja2.Environment | None = None,
+ ) -> None:
+ """Template all attributes listed in *self.template_fields*.
- Different from the BaseOperator implementation, this renders the
- template fields on the *unmapped* BaseOperator.
+ This updates *context* to reference the map-expanded task and relevant
+ information, without modifying the mapped operator. The expanded task
+ in *context* is then rendered in-place.
- :param context: Dict with values to apply on content
- :param jinja_env: Jinja environment
- :return: The unmapped, populated BaseOperator
+ :param context: Context dict with values to apply on content.
+ :param jinja_env: Jinja environment to use for rendering.
"""
if not jinja_env:
jinja_env = self.get_template_env()
- # Before we unmap we have to resolve the mapped arguments, otherwise the real operator constructor
- # could be called with an XComArg, rather than the value it resolves to.
- #
- # We also need to resolve _all_ mapped arguments, even if they aren't marked as templated
- kwargs = self._get_unmap_kwargs()
-
- template_fields = set(self.template_fields)
-
- # Ideally we'd like to pass in session as an argument to this function, but since operators _could_
- # override this we can't easily change this function signature.
- # We can't use @provide_session, as that closes and expunges everything, which we don't want to do
- # when we are so "deep" in the weeds here.
- #
- # Nor do we want to close the session -- that would expunge all the things from the internal cache
- # which we don't want to do either
+
+ # Ideally we'd like to pass in session as an argument to this function,
+ # but we can't easily change this function signature since operators
+ # could override this. We can't use @provide_session since it closes and
+ # expunges everything, which we don't want to do when we are so "deep"
+ # in the weeds here. We don't close this session for the same reason.
session = settings.Session()
- self._resolve_expansion_kwargs(kwargs, template_fields, context, session)
- unmapped_task = self.unmap(unmap_kwargs=kwargs)
+ mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
+ unmapped_task = self.unmap(mapped_kwargs)
+ context_update_for_unmapped(context, unmapped_task)
+
self._do_render_template_fields(
parent=unmapped_task,
- template_fields=template_fields,
+ template_fields=self.template_fields,
context=context,
jinja_env=jinja_env,
- seen_oids=set(),
+ seen_oids=seen_oids,
session=session,
)
- return unmapped_task
-
- def _resolve_expansion_kwargs(
- self, kwargs: Dict[str, Any], template_fields: Set[str], context: Context, session: Session
- ) -> None:
- """Update mapped fields in place in kwargs dict"""
- from airflow.models.xcom_arg import XComArg
-
- expansion_kwargs = self._get_expansion_kwargs()
-
- for k, v in expansion_kwargs.items():
- if isinstance(v, XComArg):
- v = v.resolve(context, session=session)
- v = self._expand_mapped_field(k, v, context, session=session)
- template_fields.discard(k)
- kwargs[k] = v
-
- def _expand_mapped_field(self, key: str, value: Any, context: Context, *, session: Session) -> Any:
- map_index = context["ti"].map_index
- if map_index < 0:
- return value
- expansion_kwargs = self._get_expansion_kwargs()
- all_lengths = self._resolve_map_lengths(context["run_id"], session=session)
-
- def _find_index_for_this_field(index: int) -> int:
- # Need to use self.mapped_kwargs for the original argument order.
- for mapped_key in reversed(list(expansion_kwargs)):
- mapped_length = all_lengths[mapped_key]
- if mapped_length < 1:
- raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
- if mapped_key == key:
- return index % mapped_length
- index //= mapped_length
- return -1
-
- found_index = _find_index_for_this_field(map_index)
- if found_index < 0:
- return value
- if isinstance(value, collections.abc.Sequence):
- return value[found_index]
- if not isinstance(value, dict):
- raise TypeError(f"can't map over value of type {type(value)}")
- for i, (k, v) in enumerate(value.items()):
- if i == found_index:
- return k, v
- raise IndexError(f"index {map_index} is over mapped length")
-
- def iter_mapped_dependencies(self) -> Iterator["Operator"]:
- """Upstream dependencies that provide XComs used by this task for task mapping."""
- from airflow.models.xcom_arg import XComArg
-
- for ref in XComArg.iter_xcom_args(self._get_expansion_kwargs()):
- yield ref.operator
-
- @cached_property
- def parse_time_mapped_ti_count(self) -> Optional[int]:
- """
- Number of mapped TaskInstances that can be created at DagRun create time.
-
- :return: None if non-literal mapped arg encountered, or else total number of mapped TIs this task
- should have
- """
- total = 0
-
- for value in self._get_expansion_kwargs().values():
- if not isinstance(value, MAPPABLE_LITERAL_TYPES):
- # Non-literal type encountered, so give up
- return None
- if total == 0:
- total = len(value)
- else:
- total *= len(value)
- return total
-
- @cache
- def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> Optional[int]:
- """
- Number of mapped TaskInstances that can be created at run time, or None if upstream tasks are not
- complete yet.
-
- :return: None if upstream tasks are not complete yet, or else total number of mapped TIs this task
- should have
- """
-
- lengths = self._get_map_lengths(run_id, session=session)
- expansion_kwargs = self._get_expansion_kwargs()
-
- if not lengths or not expansion_kwargs:
- return None
-
- total = 1
- for name in expansion_kwargs:
- val = lengths.get(name)
- if val is None:
- return None
- total *= val
-
- return total
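
The removed helpers above treat a mapped task as a cross product of its expansion arguments: the total count is the product of the per-field lengths, and a flat ``map_index`` is decoded back into one position per field, with the last field varying fastest. A minimal standalone sketch of that arithmetic follows. The names ``find_index_for_field`` and ``total_mapped_count`` and the plain ``lengths`` dict are illustrative assumptions, not Airflow API.

```python
from typing import Dict


def total_mapped_count(lengths: Dict[str, int]) -> int:
    """Number of mapped task instances: the product of all field lengths."""
    total = 1
    for length in lengths.values():
        total *= length
    return total


def find_index_for_field(map_index: int, key: str, lengths: Dict[str, int]) -> int:
    """Decode a flat map_index into the position within ``key``'s values.

    ``lengths`` is assumed to preserve the original argument order.
    Returns -1 if ``key`` is not one of the mapped fields.
    """
    index = map_index
    # Walk the fields last-to-first: the last field varies fastest.
    for mapped_key in reversed(list(lengths)):
        mapped_length = lengths[mapped_key]
        if mapped_length < 1:
            raise RuntimeError(f"cannot expand field mapped to length {mapped_length!r}")
        if mapped_key == key:
            return index % mapped_length
        index //= mapped_length
    return -1


if __name__ == "__main__":
    demo_lengths = {"a": 2, "b": 3}
    for i in range(total_mapped_count(demo_lengths)):
        print(i, find_index_for_field(i, "a", demo_lengths), find_index_for_field(i, "b", demo_lengths))
```

This is the same mixed-radix decoding used to convert a flat offset into multi-dimensional array indices.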
diff --git a/airflow/models/operator.py b/airflow/models/operator.py
index b95e8dc0036b1..7352ecbdcb021 100644
--- a/airflow/models/operator.py
+++ b/airflow/models/operator.py
@@ -15,12 +15,34 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from typing import Union
+from airflow.models.abstractoperator import AbstractOperator
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
+from airflow.typing_compat import TypeGuard
Operator = Union[BaseOperator, MappedOperator]
-__all__ = ["Operator"]
+
+def needs_expansion(task: AbstractOperator) -> TypeGuard[Operator]:
+ """Whether a task needs expansion at runtime.
+
+ A task needs expansion if it either
+
+ * Is a mapped operator, or
+ * Is in a mapped task group.
+
+ This is implemented as a free function (instead of a property) so we can
+ make it a type guard.
+ """
+ if isinstance(task, MappedOperator):
+ return True
+ if task.get_closest_mapped_task_group() is not None:
+ return True
+ return False
+
+
+__all__ = ["Operator", "needs_expansion"]
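
The docstring of ``needs_expansion`` says it is a free function so that it can act as a type guard. The sketch below shows that pattern in isolation. The ``Fake*`` classes are hypothetical stand-ins for Airflow's operator classes, and ``TypeGuard`` is imported only for the type checker so the snippet also runs on interpreters older than Python 3.10.

```python
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    # Only the type checker needs this import (available in typing from Python 3.10).
    from typing import TypeGuard


class FakeAbstractOperator:
    """Hypothetical stand-in for AbstractOperator."""

    def __init__(self, mapped_group: Optional[str] = None) -> None:
        self.mapped_group = mapped_group

    def get_closest_mapped_task_group(self) -> Optional[str]:
        return self.mapped_group


class FakeBaseOperator(FakeAbstractOperator):
    """Hypothetical stand-in for BaseOperator."""


class FakeMappedOperator(FakeAbstractOperator):
    """Hypothetical stand-in for MappedOperator."""


Operator = Union[FakeBaseOperator, FakeMappedOperator]


def needs_expansion(task: FakeAbstractOperator) -> "TypeGuard[Operator]":
    """Return True if the task is mapped or sits inside a mapped task group."""
    if isinstance(task, FakeMappedOperator):
        return True
    return task.get_closest_mapped_task_group() is not None
```

When the function returns ``True``, a static type checker narrows ``task`` to ``Operator`` in the guarded branch; a property cannot narrow the type of ``self`` this way.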
diff --git a/airflow/models/param.py b/airflow/models/param.py
index fcbe7a0f931c6..d944ef20311d3 100644
--- a/airflow/models/param.py
+++ b/airflow/models/param.py
@@ -14,18 +14,26 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import contextlib
import copy
import json
+import logging
import warnings
-from typing import TYPE_CHECKING, Any, Dict, ItemsView, MutableMapping, Optional, ValuesView
+from typing import TYPE_CHECKING, Any, ItemsView, Iterable, MutableMapping, ValuesView
-from airflow.exceptions import AirflowException, ParamValidationError
+from airflow.exceptions import AirflowException, ParamValidationError, RemovedInAirflow3Warning
from airflow.utils.context import Context
+from airflow.utils.mixins import ResolveMixin
from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
from airflow.models.dag import DAG
+ from airflow.models.dagrun import DagRun
+ from airflow.models.operator import Operator
+
+logger = logging.getLogger(__name__)
class Param:
@@ -39,16 +47,16 @@ class Param:
default & description will form the schema
"""
- CLASS_IDENTIFIER = '__class'
+ CLASS_IDENTIFIER = "__class"
- def __init__(self, default: Any = NOTSET, description: Optional[str] = None, **kwargs):
+ def __init__(self, default: Any = NOTSET, description: str | None = None, **kwargs):
if default is not NOTSET:
self._warn_if_not_json(default)
self.value = default
self.description = description
- self.schema = kwargs.pop('schema') if 'schema' in kwargs else kwargs
+ self.schema = kwargs.pop("schema") if "schema" in kwargs else kwargs
- def __copy__(self) -> "Param":
+ def __copy__(self) -> Param:
return Param(self.value, self.description, schema=self.schema)
@staticmethod
@@ -59,7 +67,7 @@ def _warn_if_not_json(value):
warnings.warn(
"The use of non-json-serializable params is deprecated and will be removed in "
"a future release",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
)
def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
@@ -96,7 +104,7 @@ def resolve(self, value: Any = NOTSET, suppress_exception: bool = False) -> Any:
def dump(self) -> dict:
"""Dump the Param as a dictionary"""
- out_dict = {self.CLASS_IDENTIFIER: f'{self.__module__}.{self.__class__.__name__}'}
+ out_dict = {self.CLASS_IDENTIFIER: f"{self.__module__}.{self.__class__.__name__}"}
out_dict.update(self.__dict__)
return out_dict
@@ -112,14 +120,14 @@ class ParamsDict(MutableMapping[str, Any]):
dictionary implicitly and ideally not needed to be used directly.
"""
- __slots__ = ['__dict', 'suppress_exception']
+ __slots__ = ["__dict", "suppress_exception"]
- def __init__(self, dict_obj: Optional[Dict] = None, suppress_exception: bool = False):
+ def __init__(self, dict_obj: dict | None = None, suppress_exception: bool = False):
"""
:param dict_obj: A dict or dict like object to init ParamsDict
:param suppress_exception: Flag to suppress value exceptions while initializing the ParamsDict
"""
- params_dict: Dict[str, Param] = {}
+ params_dict: dict[str, Param] = {}
dict_obj = dict_obj or {}
for k, v in dict_obj.items():
if not isinstance(v, Param):
@@ -129,10 +137,20 @@ def __init__(self, dict_obj: Optional[Dict] = None, suppress_exception: bool = F
self.__dict = params_dict
self.suppress_exception = suppress_exception
- def __copy__(self) -> "ParamsDict":
+ def __bool__(self) -> bool:
+ return bool(self.__dict)
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, ParamsDict):
+ return self.dump() == other.dump()
+ if isinstance(other, dict):
+ return self.dump() == other
+ return NotImplemented
+
+ def __copy__(self) -> ParamsDict:
return ParamsDict(self.__dict, self.suppress_exception)
- def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> "ParamsDict":
+ def __deepcopy__(self, memo: dict[int, Any] | None) -> ParamsDict:
return ParamsDict(copy.deepcopy(self.__dict, memo), self.suppress_exception)
def __contains__(self, o: object) -> bool:
@@ -147,6 +165,9 @@ def __delitem__(self, v: str) -> None:
def __iter__(self):
return iter(self.__dict)
+ def __repr__(self):
+ return repr(self.dump())
+
def __setitem__(self, key: str, value: Any) -> None:
"""
Override for dictionary's ``setitem`` method. This method make sure that all values are of
@@ -163,7 +184,7 @@ def __setitem__(self, key: str, value: Any) -> None:
try:
param.resolve(value=value, suppress_exception=self.suppress_exception)
except ParamValidationError as ve:
- raise ParamValidationError(f'Invalid input for param {key}: {ve}') from None
+ raise ParamValidationError(f"Invalid input for param {key}: {ve}") from None
else:
# if the key isn't there already and if the value isn't of Param type create a new Param object
param = Param(value)
@@ -195,32 +216,34 @@ def update(self, *args, **kwargs) -> None:
return super().update(args[0].__dict)
super().update(*args, **kwargs)
- def dump(self) -> Dict[str, Any]:
+ def dump(self) -> dict[str, Any]:
"""Dumps the ParamsDict object as a dictionary, while suppressing exceptions"""
return {k: v.resolve(suppress_exception=True) for k, v in self.items()}
- def validate(self) -> Dict[str, Any]:
+ def validate(self) -> dict[str, Any]:
"""Validates & returns all the Params object stored in the dictionary"""
resolved_dict = {}
try:
for k, v in self.items():
resolved_dict[k] = v.resolve(suppress_exception=self.suppress_exception)
except ParamValidationError as ve:
- raise ParamValidationError(f'Invalid input for param {k}: {ve}') from None
+ raise ParamValidationError(f"Invalid input for param {k}: {ve}") from None
return resolved_dict
-class DagParam:
- """
- Class that represents a DAG run parameter & binds a simple Param object to a name within a DAG instance,
- so that it can be resolved during the run time via ``{{ context }}`` dictionary. The ideal use case of
- this class is to implicitly convert args passed to a method which is being decorated by ``@dag`` keyword.
+class DagParam(ResolveMixin):
+ """DAG run parameter reference.
- It can be used to parameterize your dags. You can overwrite its value by setting it on conf
- when you trigger your DagRun.
+ This binds a simple Param object to a name within a DAG instance, so that it
+ can be resolved during the runtime via the ``{{ context }}`` dictionary. The
+ ideal use case of this class is to implicitly convert args passed to a
+ method decorated by ``@dag``.
- This can also be used in templates by accessing ``{{context.params}}`` dictionary.
+ It can be used to parameterize a DAG. You can overwrite its value by setting
+ it on conf when you trigger your DagRun.
+
+ This can also be used in templates by accessing ``{{ context.params }}``.
**Example**:
@@ -232,18 +255,42 @@ class DagParam:
:param default: Default value used if no parameter was set.
"""
- def __init__(self, current_dag: "DAG", name: str, default: Any = NOTSET):
+ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
if default is not NOTSET:
current_dag.params[name] = default
self._name = name
self._default = default
+ def iter_references(self) -> Iterable[tuple[Operator, str]]:
+ return ()
+
def resolve(self, context: Context) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
- return context['dag_run'].conf[self._name]
+ return context["dag_run"].conf[self._name]
if self._default is not NOTSET:
return self._default
with contextlib.suppress(KeyError):
- return context['params'][self._name]
- raise AirflowException(f'No value could be resolved for parameter {self._name}')
+ return context["params"][self._name]
+ raise AirflowException(f"No value could be resolved for parameter {self._name}")
+
+
+def process_params(
+ dag: DAG,
+ task: Operator,
+ dag_run: DagRun | None,
+ *,
+ suppress_exception: bool,
+) -> dict[str, Any]:
+ """Merge, validate params, and convert them into a simple dict."""
+ from airflow.configuration import conf
+
+ params = ParamsDict(suppress_exception=suppress_exception)
+ with contextlib.suppress(AttributeError):
+ params.update(dag.params)
+ if task.params:
+ params.update(task.params)
+ if conf.getboolean("core", "dag_run_conf_overrides_params") and dag_run and dag_run.conf:
+ logger.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
+ params.update(dag_run.conf)
+ return params.validate()
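
``process_params`` above layers three sources, with later ones winning: DAG params, then task params, then ``DagRun.conf`` when the ``dag_run_conf_overrides_params`` option is enabled. The sketch below reproduces only that precedence with plain dictionaries. ``merge_params`` and its arguments are illustrative names, and the validation step done by ``ParamsDict.validate()`` is deliberately left out.

```python
from typing import Any, Dict, Optional


def merge_params(
    dag_params: Optional[Dict[str, Any]],
    task_params: Optional[Dict[str, Any]],
    run_conf: Optional[Dict[str, Any]],
    *,
    conf_overrides_params: bool,
) -> Dict[str, Any]:
    """Merge parameter sources; later sources override earlier ones."""
    merged: Dict[str, Any] = {}
    merged.update(dag_params or {})
    merged.update(task_params or {})
    # The run configuration is applied only when the override flag is on.
    if conf_overrides_params and run_conf:
        merged.update(run_conf)
    return merged


if __name__ == "__main__":
    print(
        merge_params(
            {"env": "dev", "retries": 1},
            {"retries": 3},
            {"env": "prod"},
            conf_overrides_params=True,
        )
    )
```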
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index f195b93d42623..d6b8ad38872dd 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import Dict, Iterable, Optional, Tuple
+from typing import Iterable
from sqlalchemy import Column, Integer, String, Text, func
from sqlalchemy.orm.session import Session
@@ -50,7 +51,7 @@ class Pool(Base):
slots = Column(Integer, default=0)
description = Column(Text)
- DEFAULT_POOL_NAME = 'default_pool'
+ DEFAULT_POOL_NAME = "default_pool"
def __repr__(self):
return str(self.pool)
@@ -141,7 +142,7 @@ def slots_stats(
*,
lock_rows: bool = False,
session: Session = NEW_SESSION,
- ) -> Dict[str, PoolStats]:
+ ) -> dict[str, PoolStats]:
"""
Get Pool stats (Number of Running, Queued, Open & Total tasks)
@@ -154,17 +155,17 @@ def slots_stats(
"""
from airflow.models.taskinstance import TaskInstance # Avoid circular import
- pools: Dict[str, PoolStats] = {}
+ pools: dict[str, PoolStats] = {}
query = session.query(Pool.pool, Pool.slots)
if lock_rows:
query = with_row_locks(query, session=session, **nowait(session))
- pool_rows: Iterable[Tuple[str, int]] = query.all()
+ pool_rows: Iterable[tuple[str, int]] = query.all()
for (pool_name, total_slots) in pool_rows:
if total_slots == -1:
- total_slots = float('inf') # type: ignore
+ total_slots = float("inf") # type: ignore
pools[pool_name] = PoolStats(total=total_slots, running=0, queued=0, open=0)
state_count_by_pool = (
@@ -178,7 +179,7 @@ def slots_stats(
# Some databases return decimal.Decimal here.
count = int(count)
- stats_dict: Optional[PoolStats] = pools.get(pool_name)
+ stats_dict: PoolStats | None = pools.get(pool_name)
if not stats_dict:
continue
# TypedDict key must be a string literal, so we use if-statements to set value
@@ -202,10 +203,10 @@ def to_json(self):
:return: the pool object in json format
"""
return {
- 'id': self.id,
- 'pool': self.pool,
- 'slots': self.slots,
- 'description': self.description,
+ "id": self.id,
+ "pool": self.pool,
+ "slots": self.slots,
+ "description": self.description,
}
@provide_session
@@ -262,6 +263,24 @@ def queued_slots(self, session: Session = NEW_SESSION):
or 0
)
+ @provide_session
+ def scheduled_slots(self, session: Session = NEW_SESSION):
+ """
+ Get the number of slots scheduled at the moment.
+
+ :param session: SQLAlchemy ORM Session
+ :return: the number of scheduled slots
+ """
+ from airflow.models.taskinstance import TaskInstance # Avoid circular import
+
+ return int(
+ session.query(func.sum(TaskInstance.pool_slots))
+ .filter(TaskInstance.pool == self.pool)
+ .filter(TaskInstance.state == State.SCHEDULED)
+ .scalar()
+ or 0
+ )
+
@provide_session
def open_slots(self, session: Session = NEW_SESSION) -> float:
"""
@@ -271,6 +290,6 @@ def open_slots(self, session: Session = NEW_SESSION) -> float:
:return: the number of slots
"""
if self.slots == -1:
- return float('inf')
+ return float("inf")
else:
return self.slots - self.occupied_slots(session)
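
The pool methods above share two conventions: a pool with ``slots == -1`` is unbounded, and the slots used in a given state are the sum of ``pool_slots`` over the matching task instances. A small in-memory sketch of both, without a database, follows. The dictionary-based task instances and the function names are assumptions for illustration only.

```python
from typing import Dict, Iterable, Union

TaskRow = Dict[str, Union[str, int]]


def slots_in_state(task_instances: Iterable[TaskRow], pool: str, state: str) -> int:
    """Sum pool_slots for the task instances of ``pool`` that are in ``state``."""
    return sum(
        int(ti["pool_slots"])
        for ti in task_instances
        if ti["pool"] == pool and ti["state"] == state
    )


def open_slots(total_slots: int, occupied_slots: int) -> float:
    """Free slots in a pool; -1 total slots means the pool is unbounded."""
    if total_slots == -1:
        return float("inf")
    return total_slots - occupied_slots


if __name__ == "__main__":
    tis = [
        {"pool": "default_pool", "state": "scheduled", "pool_slots": 2},
        {"pool": "default_pool", "state": "running", "pool_slots": 1},
    ]
    print(slots_in_state(tis, "default_pool", "scheduled"))
    print(open_slots(5, 3))
```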
diff --git a/airflow/models/renderedtifields.py b/airflow/models/renderedtifields.py
index c7bad78b5f3f4..7fd2c29edfa61 100644
--- a/airflow/models/renderedtifields.py
+++ b/airflow/models/renderedtifields.py
@@ -16,11 +16,13 @@
# specific language governing permissions and limitations
# under the License.
"""Save Rendered Template Fields"""
+from __future__ import annotations
+
import os
-from typing import Optional
+from typing import TYPE_CHECKING
import sqlalchemy_jsonfield
-from sqlalchemy import Column, ForeignKeyConstraint, Integer, and_, not_, tuple_
+from sqlalchemy import Column, ForeignKeyConstraint, Integer, PrimaryKeyConstraint, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Session, relationship
@@ -31,6 +33,10 @@
from airflow.settings import json
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_not_in_condition
+
+if TYPE_CHECKING:
+ from sqlalchemy.sql import FromClause
class RenderedTaskInstanceFields(Base):
@@ -41,11 +47,19 @@ class RenderedTaskInstanceFields(Base):
dag_id = Column(StringID(), primary_key=True)
task_id = Column(StringID(), primary_key=True)
run_id = Column(StringID(), primary_key=True)
- map_index = Column(Integer, primary_key=True, server_default='-1')
+ map_index = Column(Integer, primary_key=True, server_default=text("-1"))
rendered_fields = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=False)
k8s_pod_yaml = Column(sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
__table_args__ = (
+ PrimaryKeyConstraint(
+ "dag_id",
+ "task_id",
+ "run_id",
+ "map_index",
+ name="rendered_task_instance_fields_pkey",
+ mssql_clustered=True,
+ ),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
@@ -54,13 +68,13 @@ class RenderedTaskInstanceFields(Base):
"task_instance.run_id",
"task_instance.map_index",
],
- name='rtif_ti_fkey',
+ name="rtif_ti_fkey",
ondelete="CASCADE",
),
)
task_instance = relationship(
"TaskInstance",
- lazy='joined',
+ lazy="joined",
back_populates="rendered_task_instance_fields",
)
@@ -98,7 +112,7 @@ def __repr__(self):
prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
if self.map_index != -1:
prefix += f" map_index={self.map_index}"
- return prefix + '>'
+ return prefix + ">"
def _redact(self):
from airflow.utils.log.secrets_masker import redact
@@ -111,7 +125,7 @@ def _redact(self):
@classmethod
@provide_session
- def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> Optional[dict]:
+ def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
"""
Get templated field for a TaskInstance from the RenderedTaskInstanceFields
table.
@@ -139,7 +153,7 @@ def get_templated_fields(cls, ti: TaskInstance, session: Session = NEW_SESSION)
@classmethod
@provide_session
- def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> Optional[dict]:
+ def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
"""
Get rendered Kubernetes Pod Yaml for a TaskInstance from the RenderedTaskInstanceFields
table.
@@ -174,9 +188,9 @@ def delete_old_records(
cls,
task_id: str,
dag_id: str,
- num_to_keep=conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0),
- session: Session = None,
- ):
+ num_to_keep: int = conf.getint("core", "max_num_rendered_ti_fields_per_task", fallback=0),
+ session: Session = NEW_SESSION,
+ ) -> None:
"""
Keep only Last X (num_to_keep) number of records for a task by deleting others.
@@ -202,49 +216,30 @@ def delete_old_records(
.limit(num_to_keep)
)
- if session.bind.dialect.name in ["postgresql", "sqlite"]:
- # Fetch Top X records given dag_id & task_id ordered by Execution Date
- subq1 = tis_to_keep_query.subquery()
- excluded = session.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.run_id)
- session.query(cls).filter(
- cls.dag_id == dag_id,
- cls.task_id == task_id,
- tuple_(cls.dag_id, cls.task_id, cls.run_id).notin_(excluded),
- ).delete(synchronize_session=False)
- elif session.bind.dialect.name in ["mysql"]:
- cls._remove_old_rendered_ti_fields_mysql(dag_id, session, task_id, tis_to_keep_query)
- else:
- # Fetch Top X records given dag_id & task_id ordered by Execution Date
- tis_to_keep = tis_to_keep_query.all()
-
- filter_tis = [
- not_(
- and_(
- cls.dag_id == ti.dag_id,
- cls.task_id == ti.task_id,
- cls.run_id == ti.run_id,
- )
- )
- for ti in tis_to_keep
- ]
-
- session.query(cls).filter(and_(*filter_tis)).delete(synchronize_session=False)
-
+ cls._do_delete_old_records(
+ dag_id=dag_id,
+ task_id=task_id,
+ ti_clause=tis_to_keep_query.subquery(),
+ session=session,
+ )
session.flush()
@classmethod
@retry_db_transaction
- def _remove_old_rendered_ti_fields_mysql(cls, dag_id, session, task_id, tis_to_keep_query):
- # Fetch Top X records given dag_id & task_id ordered by Execution Date
- subq1 = tis_to_keep_query.subquery('subq1')
- # Second Subquery
- # Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525)
- # Limitation: This version of MySQL does not yet support
- # LIMIT & IN/ALL/ANY/SOME subquery
- subq2 = session.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.run_id).subquery('subq2')
+ def _do_delete_old_records(
+ cls,
+ *,
+ task_id: str,
+ dag_id: str,
+ ti_clause: FromClause,
+ session: Session,
+ ) -> None:
# This query might deadlock occasionally and it should be retried if fails (see decorator)
session.query(cls).filter(
cls.dag_id == dag_id,
cls.task_id == task_id,
- tuple_(cls.dag_id, cls.task_id, cls.run_id).notin_(subq2),
+ tuple_not_in_condition(
+ (cls.dag_id, cls.task_id, cls.run_id),
+ session.query(ti_clause.c.dag_id, ti_clause.c.task_id, ti_clause.c.run_id),
+ ),
).delete(synchronize_session=False)
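
The refactored ``delete_old_records`` keeps the newest ``num_to_keep`` rows for one task and deletes the rest through a single "not in the keep-set" condition. The sketch below shows the same idea with the standard-library ``sqlite3`` module. The table layout, the ``execution_order`` column, and the function names are hypothetical. Because ``dag_id`` and ``task_id`` are already fixed by the filter, the sketch compares only ``run_id`` instead of the three-column tuple used in the diff.

```python
import sqlite3


def make_demo_db() -> sqlite3.Connection:
    """Create an in-memory table with a few rendered-field rows (demo data)."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE rendered_fields ("
        "dag_id TEXT, task_id TEXT, run_id TEXT, execution_order INTEGER)"
    )
    rows = [("d", "t", f"run_{i}", i) for i in range(1, 6)]
    rows += [("d", "other", f"run_{i}", i) for i in range(1, 3)]
    conn.executemany("INSERT INTO rendered_fields VALUES (?, ?, ?, ?)", rows)
    return conn


def delete_old_records(conn: sqlite3.Connection, dag_id: str, task_id: str, num_to_keep: int) -> None:
    """Keep only the newest ``num_to_keep`` rows for one task; delete the others."""
    conn.execute(
        """
        DELETE FROM rendered_fields
        WHERE dag_id = ? AND task_id = ?
          AND run_id NOT IN (
              SELECT run_id FROM rendered_fields
              WHERE dag_id = ? AND task_id = ?
              ORDER BY execution_order DESC
              LIMIT ?
          )
        """,
        (dag_id, task_id, dag_id, task_id, num_to_keep),
    )


if __name__ == "__main__":
    demo = make_demo_db()
    delete_old_records(demo, "d", "t", 2)
    print(demo.execute("SELECT task_id, run_id FROM rendered_fields").fetchall())
```

Some databases reject or restrict tuple ``NOT IN`` and ``LIMIT`` inside such subqueries, which is presumably why the diff routes the condition through a shared helper rather than branching on the dialect inline.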
diff --git a/airflow/models/sensorinstance.py b/airflow/models/sensorinstance.py
deleted file mode 100644
index b1681e52278f3..0000000000000
--- a/airflow/models/sensorinstance.py
+++ /dev/null
@@ -1,185 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-
-from sqlalchemy import BigInteger, Column, Index, Integer, String, Text
-
-from airflow.configuration import conf
-from airflow.exceptions import AirflowException
-from airflow.models.base import ID_LEN, Base
-from airflow.utils import timezone
-from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import UtcDateTime
-from airflow.utils.state import State
-
-
-class SensorInstance(Base):
- """
- SensorInstance support the smart sensor service. It stores the sensor task states
- and context that required for poking include poke context and execution context.
- In sensor_instance table we also save the sensor operator classpath so that inside
- smart sensor there is no need to import the dagbag and create task object for each
- sensor task.
-
- SensorInstance include another set of columns to support the smart sensor shard on
- large number of sensor instance. The key idea is to generate the hash code from the
- poke context and use it to map to a shorter shard code which can be used as an index.
- Every smart sensor process takes care of tasks whose `shardcode` are in a certain range.
-
- """
-
- __tablename__ = "sensor_instance"
-
- id = Column(Integer, primary_key=True)
- task_id = Column(String(ID_LEN), nullable=False)
- dag_id = Column(String(ID_LEN), nullable=False)
- execution_date = Column(UtcDateTime, nullable=False)
- state = Column(String(20))
- _try_number = Column('try_number', Integer, default=0)
- start_date = Column(UtcDateTime)
- operator = Column(String(1000), nullable=False)
- op_classpath = Column(String(1000), nullable=False)
- hashcode = Column(BigInteger, nullable=False)
- shardcode = Column(Integer, nullable=False)
- poke_context = Column(Text, nullable=False)
- execution_context = Column(Text)
- created_at = Column(UtcDateTime, default=timezone.utcnow(), nullable=False)
- updated_at = Column(UtcDateTime, default=timezone.utcnow(), onupdate=timezone.utcnow(), nullable=False)
-
- # SmartSensor doesn't support mapped operators, but this is needed for compatibly with the
- # log_filename_template of TaskInstances
- map_index = -1
-
- __table_args__ = (
- Index('ti_primary_key', dag_id, task_id, execution_date, unique=True),
- Index('si_hashcode', hashcode),
- Index('si_shardcode', shardcode),
- Index('si_state_shard', state, shardcode),
- Index('si_updated_at', updated_at),
- )
-
- def __init__(self, ti):
- self.dag_id = ti.dag_id
- self.task_id = ti.task_id
- self.execution_date = ti.execution_date
-
- @staticmethod
- def get_classpath(obj):
- """
- Get the object dotted class path. Used for getting operator classpath.
-
- :param obj:
- :return: The class path of input object
- :rtype: str
- """
- module_name, class_name = obj.__module__, obj.__class__.__name__
-
- return module_name + "." + class_name
-
- @classmethod
- @provide_session
- def register(cls, ti, poke_context, execution_context, session=None):
- """
- Register task instance ti for a sensor in sensor_instance table. Persist the
- context used for a sensor and set the sensor_instance table state to sensing.
-
- :param ti: The task instance for the sensor to be registered.
- :param poke_context: Context used for sensor poke function.
- :param execution_context: Context used for execute sensor such as timeout
- setting and email configuration.
- :param session: SQLAlchemy ORM Session
- :return: True if the ti was registered successfully.
- :rtype: Boolean
- """
- if poke_context is None:
- raise AirflowException('poke_context should not be None')
-
- encoded_poke = json.dumps(poke_context)
- encoded_execution_context = json.dumps(execution_context)
-
- sensor = (
- session.query(SensorInstance)
- .filter(
- SensorInstance.dag_id == ti.dag_id,
- SensorInstance.task_id == ti.task_id,
- SensorInstance.execution_date == ti.execution_date,
- )
- .with_for_update()
- .first()
- )
-
- if sensor is None:
- sensor = SensorInstance(ti=ti)
-
- sensor.operator = ti.operator
- sensor.op_classpath = SensorInstance.get_classpath(ti.task)
- sensor.poke_context = encoded_poke
- sensor.execution_context = encoded_execution_context
-
- sensor.hashcode = hash(encoded_poke)
- sensor.shardcode = sensor.hashcode % conf.getint('smart_sensor', 'shard_code_upper_limit')
- sensor.try_number = ti.try_number
-
- sensor.state = State.SENSING
- sensor.start_date = timezone.utcnow()
- session.add(sensor)
- session.commit()
-
- return True
-
- @property
- def try_number(self):
- """
- Return the try number that this task number will be when it is actually
- run.
- If the TI is currently running, this will match the column in the
- database, in all other cases this will be incremented.
- """
- # This is designed so that task logs end up in the right file.
- if self.state in State.running:
- return self._try_number
- return self._try_number + 1
-
- @try_number.setter
- def try_number(self, value):
- self._try_number = value
-
- def __repr__(self):
- return (
- "<{self.__class__.__name__}: id: {self.id} poke_context: {self.poke_context} "
- "execution_context: {self.execution_context} state: {self.state}>".format(self=self)
- )
-
- @provide_session
- def get_dagrun(self, session):
- """
- Returns the DagRun for this SensorInstance
-
- :param session: SQLAlchemy ORM Session
- :return: DagRun
- """
- from airflow.models.dagrun import DagRun # Avoid circular import
-
- dr = (
- session.query(DagRun)
- .filter(DagRun.dag_id == self.dag_id, DagRun.execution_date == self.execution_date)
- .one()
- )
-
- return dr
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 45424709a8421..3e1ec0c70cce3 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -15,17 +15,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Serialized DAG table in database."""
+from __future__ import annotations
import hashlib
import logging
import zlib
from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional
import sqlalchemy_jsonfield
-from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_
+from sqlalchemy import BigInteger, Column, Index, LargeBinary, String, and_, or_
from sqlalchemy.orm import Session, backref, foreign, relationship
from sqlalchemy.sql.expression import func, literal
@@ -62,23 +61,24 @@ class SerializedDagModel(Base):
it solves the webserver scalability issue.
"""
- __tablename__ = 'serialized_dag'
+ __tablename__ = "serialized_dag"
dag_id = Column(String(ID_LEN), primary_key=True)
fileloc = Column(String(2000), nullable=False)
# The max length of fileloc exceeds the limit of indexing.
- fileloc_hash = Column(BigInteger, nullable=False)
- _data = Column('data', sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
- _data_compressed = Column('data_compressed', LargeBinary, nullable=True)
+ fileloc_hash = Column(BigInteger(), nullable=False)
+ _data = Column("data", sqlalchemy_jsonfield.JSONField(json=json), nullable=True)
+ _data_compressed = Column("data_compressed", LargeBinary, nullable=True)
last_updated = Column(UtcDateTime, nullable=False)
dag_hash = Column(String(32), nullable=False)
+ processor_subdir = Column(String(2000), nullable=True)
- __table_args__ = (Index('idx_fileloc_hash', fileloc_hash, unique=False),)
+ __table_args__ = (Index("idx_fileloc_hash", fileloc_hash, unique=False),)
dag_runs = relationship(
DagRun,
primaryjoin=dag_id == foreign(DagRun.dag_id), # type: ignore
- backref=backref('serialized_dag', uselist=False, innerjoin=True),
+ backref=backref("serialized_dag", uselist=False, innerjoin=True),
)
dag_model = relationship(
@@ -87,16 +87,17 @@ class SerializedDagModel(Base):
foreign_keys=dag_id,
uselist=False,
innerjoin=True,
- backref=backref('serialized_dag', uselist=False, innerjoin=True),
+ backref=backref("serialized_dag", uselist=False, innerjoin=True),
)
load_op_links = True
- def __init__(self, dag: DAG):
+ def __init__(self, dag: DAG, processor_subdir: str | None = None):
self.dag_id = dag.dag_id
self.fileloc = dag.fileloc
self.fileloc_hash = DagCode.dag_fileloc_hash(self.fileloc)
self.last_updated = timezone.utcnow()
+ self.processor_subdir = processor_subdir
dag_data = SerializedDAG.to_dict(dag)
dag_data_json = json.dumps(dag_data, sort_keys=True).encode("utf-8")
@@ -119,7 +120,13 @@ def __repr__(self):
@classmethod
@provide_session
- def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session: Session = None) -> bool:
+ def write_dag(
+ cls,
+ dag: DAG,
+ min_update_interval: int | None = None,
+ processor_subdir: str | None = None,
+ session: Session = None,
+ ) -> bool:
"""Serializes a DAG and writes it into database.
If the record already exists, it checks if the Serialized DAG changed or not. If it is
changed, it updates the record, ignores otherwise.
@@ -142,19 +149,21 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session:
(timezone.utcnow() - timedelta(seconds=min_update_interval)) < cls.last_updated,
)
)
- .first()
- is not None
+ .scalar()
):
- # TODO: .first() is not None can be changed to .scalar() once we update to sqlalchemy 1.4+
- # as the associated sqlalchemy bug for MySQL was fixed
- # related issue : https://github.com/sqlalchemy/sqlalchemy/issues/5481
return False
log.debug("Checking if DAG (%s) changed", dag.dag_id)
- new_serialized_dag = cls(dag)
- serialized_dag_hash_from_db = session.query(cls.dag_hash).filter(cls.dag_id == dag.dag_id).scalar()
+ new_serialized_dag = cls(dag, processor_subdir)
+ serialized_dag_db = (
+ session.query(cls.dag_hash, cls.processor_subdir).filter(cls.dag_id == dag.dag_id).first()
+ )
- if serialized_dag_hash_from_db == new_serialized_dag.dag_hash:
+ if (
+ serialized_dag_db is not None
+ and serialized_dag_db.dag_hash == new_serialized_dag.dag_hash
+ and serialized_dag_db.processor_subdir == new_serialized_dag.processor_subdir
+ ):
log.debug("Serialized DAG (%s) is unchanged. Skipping writing to DB", dag.dag_id)
return False
@@ -165,7 +174,7 @@ def write_dag(cls, dag: DAG, min_update_interval: Optional[int] = None, session:
@classmethod
@provide_session
- def read_all_dags(cls, session: Session = None) -> Dict[str, 'SerializedDAG']:
+ def read_all_dags(cls, session: Session = None) -> dict[str, SerializedDAG]:
"""Reads all DAGs in serialized_dag table.
:param session: ORM Session
@@ -206,7 +215,7 @@ def dag(self):
SerializedDAG._load_operator_extra_links = self.load_op_links
if isinstance(self.data, dict):
- dag = SerializedDAG.from_dict(self.data) # type: Any
+ dag = SerializedDAG.from_dict(self.data)
else:
dag = SerializedDAG.from_json(self.data)
return dag
@@ -222,7 +231,9 @@ def remove_dag(cls, dag_id: str, session: Session = None):
@classmethod
@provide_session
- def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
+ def remove_deleted_dags(
+ cls, alive_dag_filelocs: list[str], processor_subdir: str | None = None, session=None
+ ):
"""Deletes DAGs not included in alive_dag_filelocs.
:param alive_dag_filelocs: file paths of alive DAGs
@@ -236,7 +247,14 @@ def remove_deleted_dags(cls, alive_dag_filelocs: List[str], session=None):
session.execute(
cls.__table__.delete().where(
- and_(cls.fileloc_hash.notin_(alive_fileloc_hashes), cls.fileloc.notin_(alive_dag_filelocs))
+ and_(
+ cls.fileloc_hash.notin_(alive_fileloc_hashes),
+ cls.fileloc.notin_(alive_dag_filelocs),
+ or_(
+ cls.processor_subdir.is_(None),
+ cls.processor_subdir == processor_subdir,
+ ),
+ )
)
)
@@ -252,7 +270,7 @@ def has_dag(cls, dag_id: str, session: Session = None) -> bool:
@classmethod
@provide_session
- def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDAG']:
+ def get_dag(cls, dag_id: str, session: Session = None) -> SerializedDAG | None:
row = cls.get(dag_id, session=session)
if row:
return row.dag
@@ -260,7 +278,7 @@ def get_dag(cls, dag_id: str, session: Session = None) -> Optional['SerializedDA
@classmethod
@provide_session
- def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagModel']:
+ def get(cls, dag_id: str, session: Session = None) -> SerializedDagModel | None:
"""
Get the SerializedDAG for the given dag ID.
It will cope with being passed the ID of a subdag by looking up the
@@ -281,7 +299,7 @@ def get(cls, dag_id: str, session: Session = None) -> Optional['SerializedDagMod
@staticmethod
@provide_session
- def bulk_sync_to_db(dags: List[DAG], session: Session = None):
+ def bulk_sync_to_db(dags: list[DAG], processor_subdir: str | None = None, session: Session = None):
"""
Saves DAGs as Serialized DAG objects in the database. Each
DAG is saved in a separate database query.
@@ -293,12 +311,15 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None):
for dag in dags:
if not dag.is_subdag:
SerializedDagModel.write_dag(
- dag, min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL, session=session
+ dag=dag,
+ min_update_interval=MIN_SERIALIZED_DAG_UPDATE_INTERVAL,
+ processor_subdir=processor_subdir,
+ session=session,
)
@classmethod
@provide_session
- def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Optional[datetime]:
+ def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> datetime | None:
"""
Get the date when the Serialized DAG associated to DAG was last updated
in serialized_dag table
@@ -310,7 +331,7 @@ def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Opti
@classmethod
@provide_session
- def get_max_last_updated_datetime(cls, session: Session = None) -> Optional[datetime]:
+ def get_max_last_updated_datetime(cls, session: Session = None) -> datetime | None:
"""
Get the maximum date when any DAG was last updated in serialized_dag table
@@ -320,20 +341,19 @@ def get_max_last_updated_datetime(cls, session: Session = None) -> Optional[date
@classmethod
@provide_session
- def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> Optional[str]:
+ def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str | None:
"""
Get the latest DAG version for a given DAG ID.
:param dag_id: DAG ID
:param session: ORM Session
:return: DAG Hash, or None if the DAG is not found
- :rtype: str | None
"""
return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar()
@classmethod
@provide_session
- def get_dag_dependencies(cls, session: Session = None) -> Dict[str, List['DagDependency']]:
+ def get_dag_dependencies(cls, session: Session = None) -> dict[str, list[DagDependency]]:
"""
Get the dependencies between DAGs
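Reader's note on the `write_dag` hunk above: the revised check skips the write only when a stored row exists and both its hash and its processor subdirectory match the new values. A minimal standalone sketch of that predicate follows; `StoredRow` and `is_unchanged` are hypothetical names used for illustration and are not part of Airflow.

```python
from __future__ import annotations

from typing import NamedTuple


class StoredRow(NamedTuple):
    """Hypothetical stand-in for the (dag_hash, processor_subdir) row read from the DB."""

    dag_hash: str
    processor_subdir: str | None


def is_unchanged(stored: StoredRow | None, new_hash: str, new_subdir: str | None) -> bool:
    """Return True only when a stored row exists and both fields match."""
    return (
        stored is not None
        and stored.dag_hash == new_hash
        and stored.processor_subdir == new_subdir
    )
```

With this shape, a missing row or a changed subdirectory both fall through to the write path, which matches the three-part condition in the diff.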
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index d5b1481cb19b2..c2ec087cf6042 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -15,10 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import warnings
-from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Iterable, Sequence
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -29,8 +31,9 @@
from pendulum import DateTime
from sqlalchemy import Session
- from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
+ from airflow.models.operator import Operator
+ from airflow.models.taskmixin import DAGNode
# The key used by SkipMixin to store XCom data.
XCOM_SKIPMIXIN_KEY = "skipmixin_key"
@@ -42,18 +45,29 @@
XCOM_SKIPMIXIN_FOLLOWED = "followed"
+def _ensure_tasks(nodes: Iterable[DAGNode]) -> Sequence[Operator]:
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.models.mappedoperator import MappedOperator
+
+ return [n for n in nodes if isinstance(n, (BaseOperator, MappedOperator))]
+
+
class SkipMixin(LoggingMixin):
"""A Mixin to skip Task Instances"""
- def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator]", session: "Session"):
+ def _set_state_to_skipped(
+ self,
+ dag_run: DagRun,
+ tasks: Iterable[Operator],
+ session: Session,
+ ) -> None:
"""Used internally to set state of task instances to skipped from the same dag run."""
- task_ids = [d.task_id for d in tasks]
now = timezone.utcnow()
session.query(TaskInstance).filter(
TaskInstance.dag_id == dag_run.dag_id,
TaskInstance.run_id == dag_run.run_id,
- TaskInstance.task_id.in_(task_ids),
+ TaskInstance.task_id.in_(d.task_id for d in tasks),
).update(
{
TaskInstance.state: State.SKIPPED,
@@ -66,10 +80,10 @@ def _set_state_to_skipped(self, dag_run: "DagRun", tasks: "Iterable[BaseOperator
@provide_session
def skip(
self,
- dag_run: "DagRun",
- execution_date: "DateTime",
- tasks: Sequence["BaseOperator"],
- session: "Session" = NEW_SESSION,
+ dag_run: DagRun,
+ execution_date: DateTime,
+ tasks: Iterable[DAGNode],
+ session: Session = NEW_SESSION,
):
"""
Sets task instances to skipped from the same dag run.
@@ -83,7 +97,8 @@ def skip(
:param tasks: tasks to skip (not task_ids)
:param session: db session to use
"""
- if not tasks:
+ task_list = _ensure_tasks(tasks)
+ if not task_list:
return
if execution_date and not dag_run:
@@ -91,14 +106,14 @@ def skip(
warnings.warn(
"Passing an execution_date to `skip()` is deprecated in favour of passing a dag_run",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
dag_run = (
session.query(DagRun)
.filter(
- DagRun.dag_id == tasks[0].dag_id,
+ DagRun.dag_id == task_list[0].dag_id,
DagRun.execution_date == execution_date,
)
.one()
@@ -111,24 +126,24 @@ def skip(
if dag_run is None:
raise ValueError("dag_run is required")
- self._set_state_to_skipped(dag_run, tasks, session)
+ self._set_state_to_skipped(dag_run, task_list, session)
session.commit()
# SkipMixin may not necessarily have a task_id attribute. Only store to XCom if one is available.
- task_id: Optional[str] = getattr(self, "task_id", None)
+ task_id: str | None = getattr(self, "task_id", None)
if task_id is not None:
from airflow.models.xcom import XCom
XCom.set(
key=XCOM_SKIPMIXIN_KEY,
- value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in tasks]},
+ value={XCOM_SKIPMIXIN_SKIPPED: [d.task_id for d in task_list]},
task_id=task_id,
dag_id=dag_run.dag_id,
run_id=dag_run.run_id,
session=session,
)
- def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, Iterable[str]]):
+ def skip_all_except(self, ti: TaskInstance, branch_task_ids: None | str | Iterable[str]):
"""
This method implements the logic for a branching operator; given a single
task ID or list of task IDs to follow, this skips all other tasks
@@ -139,19 +154,40 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, It
"""
self.log.info("Following branch %s", branch_task_ids)
if isinstance(branch_task_ids, str):
- branch_task_ids = {branch_task_ids}
+ branch_task_id_set = {branch_task_ids}
+ elif isinstance(branch_task_ids, Iterable):
+ branch_task_id_set = set(branch_task_ids)
+ invalid_task_ids_type = {
+ (bti, type(bti).__name__) for bti in branch_task_ids if not isinstance(bti, str)
+ }
+ if invalid_task_ids_type:
+ raise AirflowException(
+ "'branch_task_ids' expects all task IDs to be strings. "
+ f"Invalid tasks found: {invalid_task_ids_type}."
+ )
elif branch_task_ids is None:
- branch_task_ids = ()
-
- branch_task_ids = set(branch_task_ids)
+ branch_task_id_set = set()
+ else:
+ raise AirflowException(
+ "'branch_task_ids' must be either None, a task ID, or an Iterable of IDs, "
+ f"but got {type(branch_task_ids).__name__!r}."
+ )
dag_run = ti.get_dagrun()
task = ti.task
dag = task.dag
- assert dag # For Mypy.
+ if TYPE_CHECKING:
+ assert dag
+
+ valid_task_ids = set(dag.task_ids)
+ invalid_task_ids = branch_task_id_set - valid_task_ids
+ if invalid_task_ids:
+ raise AirflowException(
+ "'branch_task_ids' must contain only valid task_ids. "
+ f"Invalid tasks found: {invalid_task_ids}."
+ )
- # At runtime, the downstream list will only be operators
- downstream_tasks = cast("List[BaseOperator]", task.downstream_list)
+ downstream_tasks = _ensure_tasks(task.downstream_list)
if downstream_tasks:
# For a branching workflow that looks like this, when "branch" does skip_all_except("task1"),
@@ -166,11 +202,11 @@ def skip_all_except(self, ti: TaskInstance, branch_task_ids: Union[None, str, It
# v /
# task1
#
- for branch_task_id in list(branch_task_ids):
- branch_task_ids.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
+ for branch_task_id in list(branch_task_id_set):
+ branch_task_id_set.update(dag.get_task(branch_task_id).get_flat_relative_ids(upstream=False))
- skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_ids]
- follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_ids]
+ skip_tasks = [t for t in downstream_tasks if t.task_id not in branch_task_id_set]
+ follow_task_ids = [t.task_id for t in downstream_tasks if t.task_id in branch_task_id_set]
self.log.info("Skipping tasks %s", [t.task_id for t in skip_tasks])
with create_session() as session:
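Reader's note on the `skip_all_except` hunks above: the new code normalizes `branch_task_ids` into a set and rejects non-string or unknown IDs before any task is skipped. The sketch below mirrors that flow in plain Python under some assumptions: `normalize_branch_task_ids` is a hypothetical helper, it raises `ValueError` where the diff raises `AirflowException`, and it materializes the iterable once so that a generator argument is validated correctly (a choice of this sketch, not something the diff does).

```python
from __future__ import annotations

from collections.abc import Iterable


def normalize_branch_task_ids(
    branch_task_ids: None | str | Iterable[str],
    valid_task_ids: set[str],
) -> set[str]:
    """Return the branch IDs as a set, rejecting non-string or unknown IDs."""
    if isinstance(branch_task_ids, str):
        # Check str first: a str is itself an Iterable of characters.
        id_set = {branch_task_ids}
    elif isinstance(branch_task_ids, Iterable):
        items = list(branch_task_ids)  # consume once, so generators work
        bad_types = [(i, type(i).__name__) for i in items if not isinstance(i, str)]
        if bad_types:
            raise ValueError(f"all task IDs must be strings, got: {bad_types}")
        id_set = set(items)
    elif branch_task_ids is None:
        id_set = set()
    else:
        raise ValueError(f"unsupported type {type(branch_task_ids).__name__!r}")

    unknown = id_set - valid_task_ids
    if unknown:
        raise ValueError(f"unknown task IDs: {sorted(unknown)}")
    return id_set
```

The order of the `isinstance` checks matters: testing for `Iterable` before `str` would split a single task ID into characters.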
diff --git a/airflow/models/slamiss.py b/airflow/models/slamiss.py
index 6c841e3ee6e05..cd2f0f1dd5271 100644
--- a/airflow/models/slamiss.py
+++ b/airflow/models/slamiss.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from sqlalchemy import Boolean, Column, Index, String, Text
@@ -39,7 +40,7 @@ class SlaMiss(Base):
description = Column(Text)
notification_sent = Column(Boolean, default=False)
- __table_args__ = (Index('sm_dag', dag_id, unique=False),)
+ __table_args__ = (Index("sm_dag", dag_id, unique=False),)
def __repr__(self):
return str((self.dag_id, self.task_id, self.execution_date.isoformat()))
diff --git a/airflow/models/taskfail.py b/airflow/models/taskfail.py
index f7de99c308cac..ead4cee000f28 100644
--- a/airflow/models/taskfail.py
+++ b/airflow/models/taskfail.py
@@ -16,8 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""Taskfail tracks the failed run durations of each task instance"""
+from __future__ import annotations
-from sqlalchemy import Column, ForeignKeyConstraint, Integer
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, text
from sqlalchemy.orm import relationship
from airflow.models.base import Base, StringID
@@ -33,12 +34,13 @@ class TaskFail(Base):
task_id = Column(StringID(), nullable=False)
dag_id = Column(StringID(), nullable=False)
run_id = Column(StringID(), nullable=False)
- map_index = Column(Integer, nullable=False)
+ map_index = Column(Integer, nullable=False, server_default=text("-1"))
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Integer)
__table_args__ = (
+ Index("idx_task_fail_task_instance", dag_id, task_id, run_id, map_index),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
@@ -47,7 +49,7 @@ class TaskFail(Base):
"task_instance.run_id",
"task_instance.map_index",
],
- name='task_fail_ti_fkey',
+ name="task_fail_ti_fkey",
ondelete="CASCADE",
),
)
@@ -79,4 +81,4 @@ def __repr__(self):
prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
if self.map_index != -1:
prefix += f" map_index={self.map_index}"
- return prefix + '>'
+ return prefix + ">"
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0f5d49b819762..9708d73124b47 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import collections.abc
import contextlib
import hashlib
@@ -28,37 +30,24 @@
from datetime import datetime, timedelta
from functools import partial
from types import TracebackType
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Collection,
- ContextManager,
- Dict,
- Generator,
- Iterable,
- List,
- NamedTuple,
- Optional,
- Set,
- Tuple,
- Union,
-)
+from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, NamedTuple, Tuple
from urllib.parse import quote
-import attr
import dill
import jinja2
+import lazy_object_proxy
import pendulum
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import (
Column,
+ DateTime,
Float,
ForeignKeyConstraint,
Index,
Integer,
- PickleType,
+ PrimaryKeyConstraint,
String,
+ Text,
and_,
false,
func,
@@ -70,38 +59,38 @@
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.orm.exc import NoResultFound
-from sqlalchemy.orm.query import Query
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
-from sqlalchemy.sql.expression import ColumnOperators
-from sqlalchemy.sql.sqltypes import BigInteger
+from sqlalchemy.sql.expression import ColumnOperators, case
from airflow import settings
from airflow.compat.functools import cache
from airflow.configuration import conf
+from airflow.datasets import Dataset
+from airflow.datasets.manager import dataset_manager
from airflow.exceptions import (
AirflowException,
AirflowFailException,
AirflowRescheduleException,
AirflowSensorTimeout,
AirflowSkipException,
- AirflowSmartSensorException,
AirflowTaskTimeout,
DagRunNotFound,
+ RemovedInAirflow3Warning,
TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
UnmappableXComTypePushed,
XComForMappingNotPushed,
)
-from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
+from airflow.models.base import Base, StringID
from airflow.models.log import Log
-from airflow.models.param import ParamsDict
+from airflow.models.mappedoperator import MappedOperator
+from airflow.models.param import process_params
from airflow.models.taskfail import TaskFail
from airflow.models.taskmap import TaskMap
from airflow.models.taskreschedule import TaskReschedule
-from airflow.models.xcom import XCOM_RETURN_KEY, XCom
+from airflow.models.xcom import XCOM_RETURN_KEY, LazyXComAccess, XCom
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sentry import Sentry
from airflow.stats import Stats
@@ -109,7 +98,7 @@
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.timetables.base import DataInterval
-from airflow.typing_compat import Literal
+from airflow.typing_compat import Literal, TypeGuard
from airflow.utils import timezone
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor, context_merge
from airflow.utils.email import send_email
@@ -120,21 +109,30 @@
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import (
+ ExecutorConfigType,
+ ExtendedJSON,
+ UtcDateTime,
+ tuple_in_condition,
+ with_row_locks,
+)
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.timeout import timeout
TR = TaskReschedule
-_CURRENT_CONTEXT: List[Context] = []
+_CURRENT_CONTEXT: list[Context] = []
log = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from airflow.models.abstractoperator import TaskStateChangeCallback
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG, DagModel
from airflow.models.dagrun import DagRun
+ from airflow.models.dataset import DatasetEvent
from airflow.models.operator import Operator
+ from airflow.utils.task_group import MappedTaskGroup, TaskGroup
@contextlib.contextmanager
@@ -157,12 +155,12 @@ def set_current_context(context: Context) -> Generator[Context, None, None]:
def clear_task_instances(
- tis,
- session,
- activate_dag_runs=None,
- dag=None,
- dag_run_state: Union[DagRunState, Literal[False]] = DagRunState.QUEUED,
-):
+ tis: list[TaskInstance],
+ session: Session,
+ activate_dag_runs: None = None,
+ dag: DAG | None = None,
+ dag_run_state: DagRunState | Literal[False] = DagRunState.QUEUED,
+) -> None:
"""
Clears a set of task instances, but makes sure the running ones
get killed.
@@ -176,7 +174,7 @@ def clear_task_instances(
"""
job_ids = []
# Keys: dag_id -> run_id -> map_indexes -> try_numbers -> task_id
- task_id_by_key: Dict[str, Dict[str, Dict[int, Dict[int, Set[str]]]]] = defaultdict(
+ task_id_by_key: dict[str, dict[str, dict[int, dict[int, set[str]]]]] = defaultdict(
lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(set)))
)
for ti in tis:
@@ -201,6 +199,7 @@ def clear_task_instances(
ti.max_tries = max(ti.max_tries, ti.prev_attempted_tries)
ti.state = None
ti.external_executor_id = None
+ ti.clear_next_method_args()
session.merge(ti)
task_id_by_key[ti.dag_id][ti.run_id][ti.map_index][ti.try_number].add(ti.task_id)
@@ -249,7 +248,7 @@ def clear_task_instances(
warnings.warn(
"`activate_dag_runs` parameter to clear_task_instances function is deprecated. "
"Please use `dag_run_state`",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
if not activate_dag_runs:
@@ -279,91 +278,20 @@ def clear_task_instances(
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
+ session.flush()
-class _LazyXComAccessIterator(collections.abc.Iterator):
- __slots__ = ['_cm', '_it']
+def _is_mappable_value(value: Any) -> TypeGuard[Collection]:
+ """Whether a value can be used for task mapping.
- def __init__(self, cm: ContextManager[Query]):
- self._cm = cm
- self._it = None
-
- def __del__(self):
- if self._it:
- self._cm.__exit__(None, None, None)
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if not self._it:
- self._it = iter(self._cm.__enter__())
- return XCom.deserialize_value(next(self._it))
-
-
-@attr.define
-class _LazyXComAccess(collections.abc.Sequence):
- """Wrapper to lazily pull XCom with a sequence-like interface.
-
- Note that since the session bound to the parent query may have died when we
- actually access the sequence's content, we must create a new session
- for every function call with ``with_session()``.
+ We only allow collections with guaranteed ordering, but exclude character
+ sequences since that's usually not what users would expect to be mappable.
"""
-
- dag_id: str
- run_id: str
- task_id: str
- _query: Query = attr.ib(repr=False)
- _len: Optional[int] = attr.ib(init=False, repr=False, default=None)
-
- @classmethod
- def build_from_single_xcom(cls, first: "XCom", query: Query) -> "_LazyXComAccess":
- return cls(
- dag_id=first.dag_id,
- run_id=first.run_id,
- task_id=first.task_id,
- query=query.with_entities(XCom.value)
- .filter(
- XCom.run_id == first.run_id,
- XCom.task_id == first.task_id,
- XCom.dag_id == first.dag_id,
- XCom.map_index >= 0,
- )
- .order_by(None)
- .order_by(XCom.map_index.asc()),
- )
-
- def __len__(self):
- if self._len is None:
- with self._get_bound_query() as query:
- self._len = query.count()
- return self._len
-
- def __iter__(self):
- return _LazyXComAccessIterator(self._get_bound_query())
-
- def __getitem__(self, key):
- if not isinstance(key, int):
- raise ValueError("only support index access for now")
- try:
- with self._get_bound_query() as query:
- r = query.offset(key).limit(1).one()
- except NoResultFound:
- raise IndexError(key) from None
- return XCom.deserialize_value(r)
-
- @contextlib.contextmanager
- def _get_bound_query(self) -> Generator[Query, None, None]:
- # Do we have a valid session already?
- if self._query.session and self._query.session.is_active:
- yield self._query
- return
-
- session = settings.Session()
- try:
- yield self._query.with_session(session)
- finally:
- session.close()
+ if not isinstance(value, (collections.abc.Sequence, dict)):
+ return False
+ if isinstance(value, (bytearray, bytes, str)):
+ return False
+ return True
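Reader's note on `_is_mappable_value`: the function accepts ordered sequences and dicts but excludes character sequences. The standalone copy below (renamed `is_mappable_value` and returning a plain `bool` rather than a `TypeGuard`, both choices of this sketch) shows how a few common values are classified.

```python
import collections.abc
from typing import Any


def is_mappable_value(value: Any) -> bool:
    """Standalone copy of the check above, returning a plain bool."""
    if not isinstance(value, (collections.abc.Sequence, dict)):
        return False
    if isinstance(value, (bytearray, bytes, str)):
        return False
    return True


print(is_mappable_value([1, 2, 3]))  # True: list is an ordered sequence
print(is_mappable_value({"a": 1}))   # True: dict is explicitly allowed
print(is_mappable_value("abc"))      # False: str is a Sequence but excluded
print(is_mappable_value({1, 2}))     # False: set has no guaranteed ordering
```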
class TaskInstanceKey(NamedTuple):
@@ -376,23 +304,23 @@ class TaskInstanceKey(NamedTuple):
map_index: int = -1
@property
- def primary(self) -> Tuple[str, str, str, int]:
+ def primary(self) -> tuple[str, str, str, int]:
"""Return task instance primary key part of the key"""
return self.dag_id, self.task_id, self.run_id, self.map_index
@property
- def reduced(self) -> 'TaskInstanceKey':
+ def reduced(self) -> TaskInstanceKey:
"""Remake the key by subtracting 1 from the try number to match in-memory information"""
return TaskInstanceKey(
self.dag_id, self.task_id, self.run_id, max(1, self.try_number - 1), self.map_index
)
- def with_try_number(self, try_number: int) -> 'TaskInstanceKey':
+ def with_try_number(self, try_number: int) -> TaskInstanceKey:
"""Returns TaskInstanceKey with provided ``try_number``"""
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, try_number, self.map_index)
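Reader's note on `TaskInstanceKey`: with `from __future__ import annotations` in place, methods on a `NamedTuple` can name their own class in return annotations without quotes. A reduced sketch follows; the class name `Key` and the `try_number` default of 1 are assumptions made for illustration, and `with_try_number` uses `_replace` here instead of rebuilding the tuple field by field.

```python
from __future__ import annotations

from typing import NamedTuple


class Key(NamedTuple):
    """Hypothetical, reduced stand-in for TaskInstanceKey."""

    dag_id: str
    task_id: str
    run_id: str
    try_number: int = 1  # assumed default, not shown in the diff
    map_index: int = -1

    @property
    def reduced(self) -> Key:
        # Never go below 1, mirroring max(1, try_number - 1) in the diff.
        return self._replace(try_number=max(1, self.try_number - 1))

    def with_try_number(self, try_number: int) -> Key:
        return self._replace(try_number=try_number)
```

Because the annotations are evaluated lazily, `-> Key` inside the class body does not raise `NameError` at definition time.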
@property
- def key(self) -> "TaskInstanceKey":
+ def key(self) -> TaskInstanceKey:
"""For API compatibility with TaskInstance.
Returns self
@@ -400,6 +328,16 @@ def key(self) -> "TaskInstanceKey":
return self
+def _creator_note(val):
+ """Custom creator for the ``note`` association proxy."""
+ if isinstance(val, str):
+ return TaskInstanceNote(content=val)
+ elif isinstance(val, dict):
+ return TaskInstanceNote(**val)
+ else:
+ return TaskInstanceNote(*val)
+
+
class TaskInstance(Base, LoggingMixin):
"""
Task instances store the state of a task instance. This table is the
@@ -419,38 +357,41 @@ class TaskInstance(Base, LoggingMixin):
"""
__tablename__ = "task_instance"
-
- task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
- dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
- run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True, nullable=False)
+ task_id = Column(StringID(), primary_key=True, nullable=False)
+ dag_id = Column(StringID(), primary_key=True, nullable=False)
+ run_id = Column(StringID(), primary_key=True, nullable=False)
map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
duration = Column(Float)
state = Column(String(20))
- _try_number = Column('try_number', Integer, default=0)
- max_tries = Column(Integer)
+ _try_number = Column("try_number", Integer, default=0)
+ max_tries = Column(Integer, server_default=text("-1"))
hostname = Column(String(1000))
unixname = Column(String(1000))
job_id = Column(Integer)
pool = Column(String(256), nullable=False)
- pool_slots = Column(Integer, default=1, nullable=False, server_default=text("1"))
+ pool_slots = Column(Integer, default=1, nullable=False)
queue = Column(String(256))
priority_weight = Column(Integer)
operator = Column(String(1000))
queued_dttm = Column(UtcDateTime)
queued_by_job_id = Column(Integer)
pid = Column(Integer)
- executor_config = Column(PickleType(pickler=dill))
+ executor_config = Column(ExecutorConfigType(pickler=dill))
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow)
- external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS))
+ external_executor_id = Column(StringID())
# The trigger to resume on if we are in state DEFERRED
- trigger_id = Column(BigInteger)
+ trigger_id = Column(Integer)
# Optional timeout datetime for the trigger (past this, we'll fail)
- trigger_timeout = Column(UtcDateTime)
+ trigger_timeout = Column(DateTime)
+ # The trigger_timeout should be TIMESTAMP (using UtcDateTime), but for ease of
+ # migration we are keeping it as DateTime until a change that makes an expensive
+ # migration unavoidable.
# The method to call next, and any extra arguments to pass to it.
# Usually used when resuming from DEFERRED.
@@ -461,23 +402,26 @@ class TaskInstance(Base, LoggingMixin):
# refresh_from_db() or they won't display in the UI correctly
__table_args__ = (
- Index('ti_dag_state', dag_id, state),
- Index('ti_dag_run', dag_id, run_id),
- Index('ti_state', state),
- Index('ti_state_lkp', dag_id, task_id, run_id, state),
- Index('ti_pool', pool, state, priority_weight),
- Index('ti_job_id', job_id),
- Index('ti_trigger_id', trigger_id),
+ Index("ti_dag_state", dag_id, state),
+ Index("ti_dag_run", dag_id, run_id),
+ Index("ti_state", state),
+ Index("ti_state_lkp", dag_id, task_id, run_id, state),
+ Index("ti_pool", pool, state, priority_weight),
+ Index("ti_job_id", job_id),
+ Index("ti_trigger_id", trigger_id),
+ PrimaryKeyConstraint(
+ "dag_id", "task_id", "run_id", "map_index", name="task_instance_pkey", mssql_clustered=True
+ ),
ForeignKeyConstraint(
[trigger_id],
- ['trigger.id'],
- name='task_instance_trigger_id_fkey',
- ondelete='CASCADE',
+ ["trigger.id"],
+ name="task_instance_trigger_id_fkey",
+ ondelete="CASCADE",
),
ForeignKeyConstraint(
[dag_id, run_id],
["dag_run.dag_id", "dag_run.run_id"],
- name='task_instance_dag_run_fkey',
+ name="task_instance_dag_run_fkey",
ondelete="CASCADE",
),
)
@@ -491,27 +435,21 @@ class TaskInstance(Base, LoggingMixin):
viewonly=True,
)
- trigger = relationship(
- "Trigger",
- primaryjoin="TaskInstance.trigger_id == Trigger.id",
- foreign_keys=trigger_id,
- uselist=False,
- innerjoin=True,
- )
-
- dag_run = relationship("DagRun", back_populates="task_instances", lazy='joined', innerjoin=True)
- rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy='noload', uselist=False)
-
+ trigger = relationship("Trigger", uselist=False)
+ triggerer_job = association_proxy("trigger", "triggerer_job")
+ dag_run = relationship("DagRun", back_populates="task_instances", lazy="joined", innerjoin=True)
+ rendered_task_instance_fields = relationship("RenderedTaskInstanceFields", lazy="noload", uselist=False)
execution_date = association_proxy("dag_run", "execution_date")
-
- task: "Operator" # Not always set...
+ task_instance_note = relationship("TaskInstanceNote", back_populates="task_instance", uselist=False)
+ note = association_proxy("task_instance_note", "content", creator=_creator_note)
+ task: Operator # Not always set...
def __init__(
self,
- task: "Operator",
- execution_date: Optional[datetime] = None,
- run_id: Optional[str] = None,
- state: Optional[str] = None,
+ task: Operator,
+ execution_date: datetime | None = None,
+ run_id: str | None = None,
+ state: str | None = None,
map_index: int = -1,
):
super().__init__()
@@ -527,7 +465,7 @@ def __init__(
warnings.warn(
"Passing an execution_date to `TaskInstance()` is deprecated in favour of passing a run_id",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
# Stack level is 4 because SQLA adds some wrappers around the constructor
stacklevel=4,
)
@@ -538,7 +476,8 @@ def __init__(
execution_date,
)
if self.task.has_dag():
- assert self.task.dag # For Mypy.
+ if TYPE_CHECKING:
+ assert self.task.dag
execution_date = timezone.make_aware(execution_date, self.task.dag.timezone)
else:
execution_date = timezone.make_aware(execution_date)
@@ -562,7 +501,7 @@ def __init__(
self.unixname = getuser()
if state:
self.state = state
- self.hostname = ''
+ self.hostname = ""
# Is this TaskInstance currently running within `airflow tasks run --raw`.
# Not persisted to the database so only valid for the current process
self.raw = False
@@ -570,28 +509,28 @@ def __init__(
self.test_mode = False
@staticmethod
- def insert_mapping(run_id: str, task: "Operator", map_index: int) -> dict:
+ def insert_mapping(run_id: str, task: Operator, map_index: int) -> dict[str, Any]:
""":meta private:"""
return {
- 'dag_id': task.dag_id,
- 'task_id': task.task_id,
- 'run_id': run_id,
- '_try_number': 0,
- 'hostname': '',
- 'unixname': getuser(),
- 'queue': task.queue,
- 'pool': task.pool,
- 'pool_slots': task.pool_slots,
- 'priority_weight': task.priority_weight_total,
- 'run_as_user': task.run_as_user,
- 'max_tries': task.retries,
- 'executor_config': task.executor_config,
- 'operator': task.task_type,
- 'map_index': map_index,
+ "dag_id": task.dag_id,
+ "task_id": task.task_id,
+ "run_id": run_id,
+ "_try_number": 0,
+ "hostname": "",
+ "unixname": getuser(),
+ "queue": task.queue,
+ "pool": task.pool,
+ "pool_slots": task.pool_slots,
+ "priority_weight": task.priority_weight_total,
+ "run_as_user": task.run_as_user,
+ "max_tries": task.retries,
+ "executor_config": task.executor_config,
+ "operator": task.task_type,
+ "map_index": map_index,
}
@reconstructor
- def init_on_load(self):
+ def init_on_load(self) -> None:
"""Initialize the attributes that aren't stored in the DB"""
# correctly config the ti log
self._log = logging.getLogger("airflow.task")
@@ -607,17 +546,16 @@ def try_number(self):
database, in all other cases this will be incremented.
"""
# This is designed so that task logs end up in the right file.
- # TODO: whether we need sensing here or not (in sensor and task_instance state machine)
- if self.state in State.running:
+ if self.state == State.RUNNING:
return self._try_number
return self._try_number + 1
@try_number.setter
- def try_number(self, value):
+ def try_number(self, value: int) -> None:
self._try_number = value
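Reader's note on the `try_number` property: the getter reports the stored counter while the task is running and the next attempt number otherwise. The sketch below reproduces that behaviour with a hypothetical `TryCounter` class; the literal `"running"` stands in for `State.RUNNING` and is an assumption of this sketch.

```python
class TryCounter:
    """Hypothetical stand-in showing the try_number getter/setter pair."""

    def __init__(self) -> None:
        self._try_number = 0
        self.state = None

    @property
    def try_number(self) -> int:
        # While running, report the stored value; otherwise report the next attempt.
        if self.state == "running":
            return self._try_number
        return self._try_number + 1

    @try_number.setter
    def try_number(self, value: int) -> None:
        self._try_number = value


counter = TryCounter()
print(counter.try_number)  # 1: not running, so the next attempt is reported
counter.try_number = 1
counter.state = "running"
print(counter.try_number)  # 1: running, so the stored value is reported
counter.state = "success"
print(counter.try_number)  # 2: finished, so the next attempt is reported
```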
@property
- def prev_attempted_tries(self):
+ def prev_attempted_tries(self) -> int:
"""
Based on this instance's try_number, this will calculate
the number of previously attempted tries, defaulting to 0.
@@ -631,8 +569,7 @@ def prev_attempted_tries(self):
return self._try_number
@property
- def next_try_number(self):
- """Setting Next Try Number"""
+ def next_try_number(self) -> int:
return self._try_number + 1
def command_as_list(
@@ -654,9 +591,9 @@ def command_as_list(
installed. This command is part of the message sent to executors by
the orchestrator.
"""
- dag: Union["DAG", "DagModel"]
+ dag: DAG | DagModel
# Use the dag if we have it, else fallback to the ORM dag_model, which might not be loaded
- if hasattr(self, 'task') and hasattr(self.task, 'dag'):
+ if hasattr(self, "task") and hasattr(self.task, "dag"):
dag = self.task.dag
else:
dag = self.dag_model
@@ -671,7 +608,7 @@ def command_as_list(
if path:
if not path.is_absolute():
- path = 'DAGS_FOLDER' / path
+ path = "DAGS_FOLDER" / path
path = str(path)
return TaskInstance.generate_command(
@@ -704,14 +641,14 @@ def generate_command(
ignore_task_deps: bool = False,
ignore_ti_state: bool = False,
local: bool = False,
- pickle_id: Optional[int] = None,
- file_path: Optional[str] = None,
+ pickle_id: int | None = None,
+ file_path: str | None = None,
raw: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- cfg_path: Optional[str] = None,
+ job_id: str | None = None,
+ pool: str | None = None,
+ cfg_path: str | None = None,
map_index: int = -1,
- ) -> List[str]:
+ ) -> list[str]:
"""
Generates the shell command required to execute this task instance.
@@ -735,7 +672,6 @@ def generate_command(
:param pool: the Airflow pool that the task should run in
:param cfg_path: the Path to the configuration file
:return: shell command that can be used to run the task instance
- :rtype: list[str]
"""
cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
if mark_success:
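Reviewer sketch (not part of the patch): the list assembled by ``generate_command`` can be pictured with a small standalone function. ``build_run_command`` is a hypothetical name, and only the two optional flags visible in this diff (``--cfg-path`` and ``--map-index``) are modeled; the real method accepts more options.

```python
from __future__ import annotations


def build_run_command(
    dag_id: str,
    task_id: str,
    run_id: str,
    cfg_path: str | None = None,
    map_index: int = -1,
) -> list[str]:
    """Hypothetical stand-in for the command list built in the diff."""
    cmd = ["airflow", "tasks", "run", dag_id, task_id, run_id]
    if cfg_path:
        cmd.extend(["--cfg-path", cfg_path])
    # -1 means "not a mapped task", so the flag is only added for real indexes.
    if map_index != -1:
        cmd.extend(["--map-index", str(map_index)])
    return cmd
```

Returning a list rather than a joined string avoids shell-quoting problems when the command is handed to an executor.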
@@ -763,22 +699,28 @@ def generate_command(
if cfg_path:
cmd.extend(["--cfg-path", cfg_path])
if map_index != -1:
- cmd.extend(['--map-index', str(map_index)])
+ cmd.extend(["--map-index", str(map_index)])
return cmd
@property
- def log_url(self):
+ def log_url(self) -> str:
"""Log URL for TaskInstance"""
iso = quote(self.execution_date.isoformat())
- base_url = conf.get('webserver', 'BASE_URL')
- return base_url + f"/log?execution_date={iso}&task_id={self.task_id}&dag_id={self.dag_id}"
+ base_url = conf.get_mandatory_value("webserver", "BASE_URL")
+ return (
+ f"{base_url}/log"
+ f"?execution_date={iso}"
+ f"&task_id={self.task_id}"
+ f"&dag_id={self.dag_id}"
+ f"&map_index={self.map_index}"
+ )
@property
- def mark_success_url(self):
+ def mark_success_url(self) -> str:
"""URL to mark TI success"""
- base_url = conf.get('webserver', 'BASE_URL')
- return base_url + (
- "/confirm"
+ base_url = conf.get_mandatory_value("webserver", "BASE_URL")
+ return (
+ f"{base_url}/confirm"
f"?task_id={self.task_id}"
f"&dag_id={self.dag_id}"
f"&dag_run_id={quote(self.run_id)}"
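Reviewer sketch (not part of the patch): the new ``log_url`` body is a plain f-string concatenation with the execution date percent-encoded. The function below reproduces that shape with the standard library only; ``build_log_url`` and its arguments are hypothetical, and the base URL is passed in instead of being read from the Airflow configuration.

```python
from datetime import datetime, timezone
from urllib.parse import quote


def build_log_url(
    base_url: str,
    execution_date: datetime,
    task_id: str,
    dag_id: str,
    map_index: int = -1,
) -> str:
    """Hypothetical stand-in mirroring the URL layout shown in the diff."""
    # quote() escapes ":" and "+" so the ISO timestamp survives as a query value.
    iso = quote(execution_date.isoformat())
    return (
        f"{base_url}/log"
        f"?execution_date={iso}"
        f"&task_id={task_id}"
        f"&dag_id={dag_id}"
        f"&map_index={map_index}"
    )


url = build_log_url(
    "http://localhost:8080",
    datetime(2022, 1, 1, tzinfo=timezone.utc),
    task_id="t1",
    dag_id="d1",
)
```

Note that, as in the diff, ``task_id`` and ``dag_id`` are interpolated without quoting; the sketch keeps that behavior rather than changing it.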
@@ -788,26 +730,22 @@ def mark_success_url(self):
)
@provide_session
- def current_state(self, session=NEW_SESSION) -> str:
+ def current_state(self, session: Session = NEW_SESSION) -> str:
"""
Get the very latest state from the database. If a session is passed,
it is used and the state lookup becomes part of that session; otherwise
a new session is used.
+ sqlalchemy.inspect is used here to get the primary keys, ensuring that the
+ query does not regress if they change.
+
:param session: SQLAlchemy ORM Session
"""
- return (
- session.query(TaskInstance.state)
- .filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == self.run_id,
- )
- .scalar()
- )
+ filters = (col == getattr(self, col.name) for col in inspect(TaskInstance).primary_key)
+ return session.query(TaskInstance.state).filter(*filters).scalar()
@provide_session
- def error(self, session=NEW_SESSION):
+ def error(self, session: Session = NEW_SESSION) -> None:
"""
Forces the task instance's state to FAILED in the database.
@@ -819,7 +757,7 @@ def error(self, session=NEW_SESSION):
session.commit()
@provide_session
- def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
+ def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool = False) -> None:
"""
Refreshes the task instance from the database based on the primary key
@@ -830,28 +768,35 @@ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
"""
self.log.debug("Refreshing TaskInstance %s from DB", self)
- qry = session.query(TaskInstance).filter(
- TaskInstance.dag_id == self.dag_id,
- TaskInstance.task_id == self.task_id,
- TaskInstance.run_id == self.run_id,
- TaskInstance.map_index == self.map_index,
+ if self in session:
+ session.refresh(self, TaskInstance.__mapper__.column_attrs.keys())
+
+ qry = (
+ # To avoid joining any relationships, by default select all
+ # columns, not the object. This also means we get (effectively) a
+ # namedtuple back, not a TI object
+ session.query(*TaskInstance.__table__.columns).filter(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.task_id == self.task_id,
+ TaskInstance.run_id == self.run_id,
+ TaskInstance.map_index == self.map_index,
+ )
)
if lock_for_update:
for attempt in run_with_db_retries(logger=self.log):
with attempt:
- ti: Optional[TaskInstance] = qry.with_for_update().first()
+ ti: TaskInstance | None = qry.with_for_update().one_or_none()
else:
- ti = qry.first()
+ ti = qry.one_or_none()
if ti:
# Fields ordered per model definition
self.start_date = ti.start_date
self.end_date = ti.end_date
self.duration = ti.duration
self.state = ti.state
- # Get the raw value of try_number column, don't read through the
- # accessor here otherwise it will be incremented by one already.
- self.try_number = ti._try_number
+ # Since we selected columns, not the object, this is the raw value
+ self.try_number = ti.try_number
self.max_tries = ti.max_tries
self.hostname = ti.hostname
self.unixname = ti.unixname
@@ -872,7 +817,7 @@ def refresh_from_db(self, session=NEW_SESSION, lock_for_update=False) -> None:
else:
self.state = None
- def refresh_from_task(self, task: "Operator", pool_override=None):
+ def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
"""
Copy common attributes from the given task.
@@ -891,7 +836,7 @@ def refresh_from_task(self, task: "Operator", pool_override=None):
self.operator = task.task_type
@provide_session
- def clear_xcom_data(self, session: Session = NEW_SESSION):
+ def clear_xcom_data(self, session: Session = NEW_SESSION) -> None:
"""Clear all XCom data from the database for the task instance.
If the task is unmapped, all XComs matching this task ID in the same DAG
@@ -902,7 +847,7 @@ def clear_xcom_data(self, session: Session = NEW_SESSION):
"""
self.log.debug("Clearing XCom data")
if self.map_index < 0:
- map_index: Optional[int] = None
+ map_index: int | None = None
else:
map_index = self.map_index
XCom.clear(
@@ -919,13 +864,17 @@ def key(self) -> TaskInstanceKey:
return TaskInstanceKey(self.dag_id, self.task_id, self.run_id, self.try_number, self.map_index)
@provide_session
- def set_state(self, state: Optional[str], session=NEW_SESSION):
+ def set_state(self, state: str | None, session: Session = NEW_SESSION) -> bool:
"""
Set TaskInstance state.
:param state: State to set for the TI
:param session: SQLAlchemy ORM Session
+ :return: whether the state was changed
"""
+ if self.state == state:
+ return False
+
current_time = timezone.utcnow()
self.log.debug("Setting task state for %s to %s", self, state)
self.state = state
@@ -934,9 +883,10 @@ def set_state(self, state: Optional[str], session=NEW_SESSION):
self.end_date = self.end_date or current_time
self.duration = (self.end_date - self.start_date).total_seconds()
session.merge(self)
+ return True
@property
- def is_premature(self):
+ def is_premature(self) -> bool:
"""
Returns whether a task is in UP_FOR_RETRY state and its retry interval
has not yet elapsed.
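Reviewer sketch (not part of the patch): the ``set_state`` change adds an early return so callers can tell whether anything was actually written. The class below is a hypothetical in-memory stand-in; the ``writes`` counter takes the place of ``session.merge(self)`` and no database is involved.

```python
from __future__ import annotations


class SketchTaskInstance:
    """Hypothetical stand-in that only tracks state and write count."""

    def __init__(self, state: str | None = None) -> None:
        self.state = state
        self.writes = 0

    def set_state(self, state: str | None) -> bool:
        # No-op when the state is unchanged, as in the patched method.
        if self.state == state:
            return False
        self.state = state
        self.writes += 1  # stands in for session.merge(self)
        return True


ti = SketchTaskInstance("queued")
unchanged = ti.set_state("queued")
changed = ti.set_state("running")
```

With the boolean result, a caller can skip follow-up work such as logging or committing when the state did not change.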
@@ -945,7 +895,7 @@ def is_premature(self):
return self.state == State.UP_FOR_RETRY and not self.ready_for_retry()
@provide_session
- def are_dependents_done(self, session=NEW_SESSION):
+ def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
"""
Checks whether the immediate dependents of this task instance have succeeded or have been skipped.
This is meant to be used by wait_for_downstream.
@@ -973,9 +923,9 @@ def are_dependents_done(self, session=NEW_SESSION):
@provide_session
def get_previous_dagrun(
self,
- state: Optional[DagRunState] = None,
- session: Optional[Session] = None,
- ) -> Optional["DagRun"]:
+ state: DagRunState | None = None,
+ session: Session | None = None,
+ ) -> DagRun | None:
"""The DagRun that ran before this task instance's DagRun.
:param state: If passed, it only takes into account instances of a specific state.
@@ -1006,9 +956,9 @@ def get_previous_dagrun(
@provide_session
def get_previous_ti(
self,
- state: Optional[DagRunState] = None,
+ state: DagRunState | None = None,
session: Session = NEW_SESSION,
- ) -> Optional['TaskInstance']:
+ ) -> TaskInstance | None:
"""
The task instance for the task that ran before this task instance.
@@ -1021,7 +971,7 @@ def get_previous_ti(
return dagrun.get_task_instance(self.task_id, session=session)
@property
- def previous_ti(self):
+ def previous_ti(self) -> TaskInstance | None:
"""
This attribute is deprecated.
Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
@@ -1031,13 +981,13 @@ def previous_ti(self):
This attribute is deprecated.
Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
""",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.get_previous_ti()
@property
- def previous_ti_success(self) -> Optional['TaskInstance']:
+ def previous_ti_success(self) -> TaskInstance | None:
"""
This attribute is deprecated.
Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
@@ -1047,7 +997,7 @@ def previous_ti_success(self) -> Optional['TaskInstance']:
This attribute is deprecated.
Please use `airflow.models.taskinstance.TaskInstance.get_previous_ti` method.
""",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.get_previous_ti(state=DagRunState.SUCCESS)
@@ -1055,9 +1005,9 @@ def previous_ti_success(self) -> Optional['TaskInstance']:
@provide_session
def get_previous_execution_date(
self,
- state: Optional[DagRunState] = None,
+ state: DagRunState | None = None,
session: Session = NEW_SESSION,
- ) -> Optional[pendulum.DateTime]:
+ ) -> pendulum.DateTime | None:
"""
The execution date from property previous_ti_success.
@@ -1070,8 +1020,8 @@ def get_previous_execution_date(
@provide_session
def get_previous_start_date(
- self, state: Optional[DagRunState] = None, session: Session = NEW_SESSION
- ) -> Optional[pendulum.DateTime]:
+ self, state: DagRunState | None = None, session: Session = NEW_SESSION
+ ) -> pendulum.DateTime | None:
"""
The start date from property previous_ti_success.
@@ -1084,7 +1034,7 @@ def get_previous_start_date(
return prev_ti and prev_ti.start_date and pendulum.instance(prev_ti.start_date)
@property
- def previous_start_date_success(self) -> Optional[pendulum.DateTime]:
+ def previous_start_date_success(self) -> pendulum.DateTime | None:
"""
This attribute is deprecated.
Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
@@ -1094,13 +1044,15 @@ def previous_start_date_success(self) -> Optional[pendulum.DateTime]:
This attribute is deprecated.
Please use `airflow.models.taskinstance.TaskInstance.get_previous_start_date` method.
""",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return self.get_previous_start_date(state=DagRunState.SUCCESS)
@provide_session
- def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=False):
+ def are_dependencies_met(
+ self, dep_context: DepContext | None = None, session: Session = NEW_SESSION, verbose: bool = False
+ ) -> bool:
"""
Returns whether or not all the conditions are met for this task instance to be run
given the context for the dependencies (e.g. a task instance being force run from
@@ -1132,7 +1084,7 @@ def are_dependencies_met(self, dep_context=None, session=NEW_SESSION, verbose=Fa
return True
@provide_session
- def get_failed_dep_statuses(self, dep_context=None, session=NEW_SESSION):
+ def get_failed_dep_statuses(self, dep_context: DepContext | None = None, session: Session = NEW_SESSION):
"""Get failed Dependencies"""
dep_context = dep_context or DepContext()
for dep in dep_context.deps | self.task.deps:
@@ -1149,7 +1101,7 @@ def get_failed_dep_statuses(self, dep_context=None, session=NEW_SESSION):
if not dep_status.passed:
yield dep_status
- def __repr__(self):
+ def __repr__(self) -> str:
prefix = f" bool:
"""
Checks on whether the task instance is in the right state and timeframe
to be retried.
@@ -1202,7 +1154,7 @@ def ready_for_retry(self):
return self.state == State.UP_FOR_RETRY and self.next_retry_datetime() < timezone.utcnow()
@provide_session
- def get_dagrun(self, session: Session = NEW_SESSION) -> "DagRun":
+ def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
"""
Returns the DagRun for this TaskInstance
@@ -1218,7 +1170,7 @@ def get_dagrun(self, session: Session = NEW_SESSION) -> "DagRun":
dr = session.query(DagRun).filter(DagRun.dag_id == self.dag_id, DagRun.run_id == self.run_id).one()
# Record it in the instance for next time. This means that `self.execution_date` will work correctly
- set_committed_value(self, 'dag_run', dr)
+ set_committed_value(self, "dag_run", dr)
return dr
@@ -1232,10 +1184,10 @@ def check_and_change_state_before_execution(
ignore_ti_state: bool = False,
mark_success: bool = False,
test_mode: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- external_executor_id: Optional[str] = None,
- session=NEW_SESSION,
+ job_id: str | None = None,
+ pool: str | None = None,
+ external_executor_id: str | None = None,
+ session: Session = NEW_SESSION,
) -> bool:
"""
Checks dependencies and then sets state to RUNNING if they are met. Returns
@@ -1254,7 +1206,6 @@ def check_and_change_state_before_execution(
:param external_executor_id: The identifier of the celery executor
:param session: SQLAlchemy ORM Session
:return: whether the state was changed to running or not
- :rtype: bool
"""
task = self.task
self.refresh_from_task(task, pool_override=pool)
@@ -1265,7 +1216,7 @@ def check_and_change_state_before_execution(
self.pid = None
if not ignore_all_deps and not ignore_ti_state and self.state == State.SUCCESS:
- Stats.incr('previously_succeeded', 1, 1)
+ Stats.incr("previously_succeeded", 1, 1)
# TODO: Logging needs cleanup, not clear what is being printed
hr_line_break = "\n" + ("-" * 80) # Line break
@@ -1349,32 +1300,32 @@ def check_and_change_state_before_execution(
self.log.info("Executing %s on %s", self.task, self.execution_date)
return True
- def _date_or_empty(self, attr: str):
- result: Optional[datetime] = getattr(self, attr, None)
- return result.strftime('%Y%m%dT%H%M%S') if result else ''
+ def _date_or_empty(self, attr: str) -> str:
+ result: datetime | None = getattr(self, attr, None)
+ return result.strftime("%Y%m%dT%H%M%S") if result else ""
- def _log_state(self, lead_msg: str = ''):
+ def _log_state(self, lead_msg: str = "") -> None:
params = [
lead_msg,
str(self.state).upper(),
self.dag_id,
self.task_id,
]
- message = '%sMarking task as %s. dag_id=%s, task_id=%s, '
+ message = "%sMarking task as %s. dag_id=%s, task_id=%s, "
if self.map_index >= 0:
params.append(self.map_index)
- message += 'map_index=%d, '
+ message += "map_index=%d, "
self.log.info(
- message + 'execution_date=%s, start_date=%s, end_date=%s',
+ message + "execution_date=%s, start_date=%s, end_date=%s",
*params,
- self._date_or_empty('execution_date'),
- self._date_or_empty('start_date'),
- self._date_or_empty('end_date'),
+ self._date_or_empty("execution_date"),
+ self._date_or_empty("start_date"),
+ self._date_or_empty("end_date"),
)
# Ensure we unset next_method and next_kwargs to ensure that any
# retries don't re-use them.
- def clear_next_method_args(self):
+ def clear_next_method_args(self) -> None:
self.log.debug("Clearing next_method and next_kwargs.")
self.next_method = None
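Reviewer sketch (not part of the patch): ``_date_or_empty`` and ``_log_state`` together build a log line whose placeholders depend on whether the task is mapped. The two functions below are hypothetical free-function versions; they return the formatted string instead of calling a logger, and the date fields are left out of the message for brevity.

```python
from __future__ import annotations

from datetime import datetime
from types import SimpleNamespace


def date_or_empty(obj: object, attr: str) -> str:
    """Format a datetime attribute compactly, or return "" when unset."""
    result: datetime | None = getattr(obj, attr, None)
    return result.strftime("%Y%m%dT%H%M%S") if result else ""


def build_state_message(
    state: str, dag_id: str, task_id: str, map_index: int = -1, lead_msg: str = ""
) -> str:
    """Assemble the message prefix; map_index is appended only for mapped tasks."""
    params: list = [lead_msg, str(state).upper(), dag_id, task_id]
    message = "%sMarking task as %s. dag_id=%s, task_id=%s, "
    if map_index >= 0:
        params.append(map_index)
        message += "map_index=%d, "
    return message % tuple(params)


ti = SimpleNamespace(start_date=datetime(2022, 3, 4, 5, 6, 7), end_date=None)
```

Keeping the format string and the parameter list in step, as the original does, is what lets the optional ``map_index`` placeholder be added safely.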
@@ -1386,9 +1337,9 @@ def _run_raw_task(
self,
mark_success: bool = False,
test_mode: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- session=NEW_SESSION,
+ job_id: str | None = None,
+ pool: str | None = None,
+ session: Session = NEW_SESSION,
) -> None:
"""
Immediately runs the task (without checking or changing db state
@@ -1411,10 +1362,10 @@ def _run_raw_task(
session.merge(self)
session.commit()
actual_start_date = timezone.utcnow()
- Stats.incr(f'ti.start.{self.task.dag_id}.{self.task.task_id}')
+ Stats.incr(f"ti.start.{self.task.dag_id}.{self.task.task_id}")
# Initialize final state counters at zero
for state in State.task_states:
- Stats.incr(f'ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}', count=0)
+ Stats.incr(f"ti.finish.{self.task.dag_id}.{self.task.task_id}.{state}", count=0)
self.task = self.task.prepare_for_execution()
context = self.get_template_context(ignore_param_exceptions=False)
@@ -1429,20 +1380,17 @@ def _run_raw_task(
# a trigger.
self._defer_task(defer=defer, session=session)
self.log.info(
- 'Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s',
+ "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
self.dag_id,
self.task_id,
- self._date_or_empty('execution_date'),
- self._date_or_empty('start_date'),
+ self._date_or_empty("execution_date"),
+ self._date_or_empty("start_date"),
)
if not test_mode:
session.add(Log(self.state, self))
session.merge(self)
session.commit()
return
- except AirflowSmartSensorException as e:
- self.log.info(e)
- return
except AirflowSkipException as e:
# Recording SKIP
# log only if exception has any arguments to prevent log flooding
@@ -1481,7 +1429,7 @@ def _run_raw_task(
session.commit()
raise
finally:
- Stats.incr(f'ti.finish.{self.dag_id}.{self.task_id}.{self.state}')
+ Stats.incr(f"ti.finish.{self.dag_id}.{self.task_id}.{self.state}")
# Recording SKIPPED or SUCCESS
self.clear_next_method_args()
@@ -1492,14 +1440,26 @@ def _run_raw_task(
# Run on_success_callback before committing to the db;
# otherwise the LocalTaskJob sees the state changed to `success` while
# the task_runner is still running, and then treats the state as set externally!
- self._run_finished_callback(self.task.on_success_callback, context, 'on_success')
+ self._run_finished_callback(self.task.on_success_callback, context, "on_success")
if not test_mode:
session.add(Log(self.state, self))
- session.merge(self)
-
+ session.merge(self).task = self.task
+ if self.state == TaskInstanceState.SUCCESS:
+ self._register_dataset_changes(session=session)
session.commit()
+ def _register_dataset_changes(self, *, session: Session) -> None:
+ for obj in self.task.outlets or []:
+ self.log.debug("outlet obj %s", obj)
+ # Lineage can have other types of objects besides datasets
+ if isinstance(obj, Dataset):
+ dataset_manager.register_dataset_change(
+ task_instance=self,
+ dataset=obj,
+ session=session,
+ )
+
def _execute_task_with_callbacks(self, context, test_mode=False):
"""Prepare Task for Execution"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1526,9 +1486,9 @@ def signal_handler(signum, frame):
if not self.next_method:
self.clear_xcom_data()
- with Stats.timer(f'dag.{self.task.dag_id}.{self.task.task_id}.duration'):
+ with Stats.timer(f"dag.{self.task.dag_id}.{self.task.task_id}.duration"):
# Set the validated/merged params on the task object.
- self.task.params = context['params']
+ self.task.params = context["params"]
task_orig = self.render_templates(context=context)
if not test_mode:
@@ -1546,7 +1506,7 @@ def signal_handler(signum, frame):
if not self.next_method:
self.log.info(
"Exporting the following env vars:\n%s",
- '\n'.join(f"{k}={v}" for k, v in airflow_context_vars.items()),
+ "\n".join(f"{k}={v}" for k, v in airflow_context_vars.items()),
)
# Run pre_execute callback
@@ -1555,22 +1515,6 @@ def signal_handler(signum, frame):
# Run on_execute callback
self._run_execute_callback(context, self.task)
- if self.task.is_smart_sensor_compatible():
- # Try to register it in the smart sensor service.
- registered = False
- try:
- registered = self.task.register_in_sensor_service(self, context)
- except Exception:
- self.log.warning(
- "Failed to register in sensor service."
- " Continue to run task in non smart sensor mode.",
- exc_info=True,
- )
-
- if registered:
- # Will raise AirflowSmartSensorException to avoid long running execution.
- self._update_ti_state_for_sensing()
-
# Execute the task
with set_current_context(context):
result = self._execute_task(context, task_orig)
@@ -1578,20 +1522,12 @@ def signal_handler(signum, frame):
# Run post_execute callback
self.task.post_execute(context=context, result=result)
- Stats.incr(f'operator_successes_{self.task.task_type}', 1, 1)
- Stats.incr('ti_successes')
+ Stats.incr(f"operator_successes_{self.task.task_type}", 1, 1)
+ Stats.incr("ti_successes")
- @provide_session
- def _update_ti_state_for_sensing(self, session=NEW_SESSION):
- self.log.info('Submitting %s to sensor service', self)
- self.state = State.SENSING
- self.start_date = timezone.utcnow()
- session.merge(self)
- session.commit()
- # Raise exception for sensing state
- raise AirflowSmartSensorException("Task successfully registered in smart sensor.")
-
- def _run_finished_callback(self, callback, context, callback_type):
+ def _run_finished_callback(
+ self, callback: TaskStateChangeCallback | None, context: Context, callback_type: str
+ ) -> None:
"""Run callback after task finishes"""
try:
if callback:
@@ -1654,7 +1590,7 @@ def _execute_task(self, context, task_orig):
return result
@provide_session
- def _defer_task(self, session, defer: TaskDeferred):
+ def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
"""
Marks the task as deferred and sets up the trigger that is needed
to resume it.
@@ -1692,7 +1628,7 @@ def _defer_task(self, session, defer: TaskDeferred):
else:
self.trigger_timeout = self.start_date + execution_timeout
- def _run_execute_callback(self, context: Context, task):
+ def _run_execute_callback(self, context: Context, task: Operator) -> None:
"""Functions that need to be run before a Task is executed"""
try:
if task.on_execute_callback:
@@ -1710,9 +1646,9 @@ def run(
ignore_ti_state: bool = False,
mark_success: bool = False,
test_mode: bool = False,
- job_id: Optional[str] = None,
- pool: Optional[str] = None,
- session=NEW_SESSION,
+ job_id: str | None = None,
+ pool: str | None = None,
+ session: Session = NEW_SESSION,
) -> None:
"""Run TaskInstance"""
res = self.check_and_change_state_before_execution(
@@ -1734,13 +1670,14 @@ def run(
mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
)
- def dry_run(self):
+ def dry_run(self) -> None:
"""Only Renders Templates for the TI"""
from airflow.models.baseoperator import BaseOperator
self.task = self.task.prepare_for_execution()
self.render_templates()
- assert isinstance(self.task, BaseOperator) # For Mypy.
+ if TYPE_CHECKING:
+ assert isinstance(self.task, BaseOperator)
self.task.dry_run()
@provide_session
@@ -1777,6 +1714,7 @@ def _handle_reschedule(
actual_start_date,
self.end_date,
reschedule_exception.reschedule_date,
+ self.map_index,
)
)
@@ -1791,10 +1729,10 @@ def _handle_reschedule(
session.merge(self)
session.commit()
- self.log.info('Rescheduling task, marking task as UP_FOR_RESCHEDULE')
+ self.log.info("Rescheduling task, marking task as UP_FOR_RESCHEDULE")
@staticmethod
- def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> Optional[TracebackType]:
+ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -> TracebackType | None:
"""
Truncates the traceback of an exception to the first frame called from within a given function
@@ -1812,14 +1750,18 @@ def get_truncated_error_traceback(error: BaseException, truncate_to: Callable) -
return tb or error.__traceback__
@provide_session
- def handle_failure(self, error, test_mode=None, context=None, force_fail=False, session=None) -> None:
+ def handle_failure(
+ self,
+ error: None | str | Exception | KeyboardInterrupt,
+ test_mode: bool | None = None,
+ context: Context | None = None,
+ force_fail: bool = False,
+ session: Session = NEW_SESSION,
+ ) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
test_mode = self.test_mode
- if context is None:
- context = self.get_template_context()
-
if error:
if isinstance(error, BaseException):
tb = self.get_truncated_error_traceback(error, truncate_to=self._execute_task)
@@ -1831,8 +1773,8 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False,
self.end_date = timezone.utcnow()
self.set_duration()
- Stats.incr(f'operator_failures_{self.task.task_type}')
- Stats.incr('ti_failures')
+ Stats.incr(f"operator_failures_{self.operator}")
+ Stats.incr("ti_failures")
if not test_mode:
session.add(Log(State.FAILED, self))
@@ -1841,8 +1783,12 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False,
self.clear_next_method_args()
+ # In extreme cases (zombie in case of dag with parse error) we might _not_ have a Task.
+ if context is None and getattr(self, "task", None):
+ context = self.get_template_context(session)
+
if context is not None:
- context['exception'] = error
+ context["exception"] = error
# Set state correctly and figure out how to log it and decide whether
# to email
@@ -1856,34 +1802,35 @@ def handle_failure(self, error, test_mode=None, context=None, force_fail=False,
# only mark task instance as FAILED if the next task instance
# try_number exceeds the max_tries ... or if force_fail is truthy
- task = None
+ task: BaseOperator | None = None
try:
- task = self.task.unmap()
+ if getattr(self, "task", None) and context:
+ task = self.task.unmap((context, session))
except Exception:
- self.log.error("Unable to unmap task, can't determine if we need to send an alert email or not")
+ self.log.error("Unable to unmap task to determine if we need to send an alert email")
if force_fail or not self.is_eligible_to_retry():
self.state = State.FAILED
- email_for_state = operator.attrgetter('email_on_failure')
+ email_for_state = operator.attrgetter("email_on_failure")
callback = task.on_failure_callback if task else None
- callback_type = 'on_failure'
+ callback_type = "on_failure"
else:
if self.state == State.QUEUED:
# We increase the try_number so as to fail the task if it fails to start after some time
self._try_number += 1
self.state = State.UP_FOR_RETRY
- email_for_state = operator.attrgetter('email_on_retry')
+ email_for_state = operator.attrgetter("email_on_retry")
callback = task.on_retry_callback if task else None
- callback_type = 'on_retry'
+ callback_type = "on_retry"
- self._log_state('Immediate failure requested. ' if force_fail else '')
+ self._log_state("Immediate failure requested. " if force_fail else "")
if task and email_for_state(task) and task.email:
try:
self.email_alert(error, task)
except Exception:
- self.log.exception('Failed to send email to: %s', task.email)
+ self.log.exception("Failed to send email to: %s", task.email)
- if callback:
+ if callback and context:
self._run_finished_callback(callback, context, callback_type)
if not test_mode:
@@ -1896,6 +1843,9 @@ def is_eligible_to_retry(self):
# If a task is cleared when running, it goes into RESTARTING state and is always
# eligible for retry
return True
+ if not getattr(self, "task", None):
+ # Couldn't load the task, don't know number of retries, guess:
+ return self.try_number <= self.max_tries
return self.task.retries and self.try_number <= self.max_tries
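Reviewer sketch (not part of the patch): the new branch in ``is_eligible_to_retry`` falls back to comparing ``try_number`` with ``max_tries`` when the task object could not be loaded. The free function below is hypothetical; ``task_retries=None`` models the missing task, the RESTARTING shortcut is omitted, and the result is coerced to ``bool`` (the original can return the falsy ``retries`` value itself).

```python
from __future__ import annotations


def is_eligible_to_retry(
    try_number: int, max_tries: int, task_retries: int | None = None
) -> bool:
    """Hypothetical stand-in for the retry eligibility check."""
    if task_retries is None:
        # Task could not be loaded: the retry count is unknown, so guess
        # from the values stored on the task instance itself.
        return try_number <= max_tries
    return bool(task_retries) and try_number <= max_tries
```

The fallback errs on the side of retrying, which matches the "guess" wording in the added comment.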
@@ -1908,56 +1858,50 @@ def get_template_context(
session = settings.Session()
from airflow import macros
+ from airflow.models.abstractoperator import NotMapped
integrate_macros_plugins()
task = self.task
- assert task.dag # For Mypy.
+ if TYPE_CHECKING:
+ assert task.dag
dag: DAG = task.dag
dag_run = self.get_dagrun(session)
data_interval = dag.get_run_data_interval(dag_run)
- # Validates Params and convert them into a simple dict.
- params = ParamsDict(suppress_exception=ignore_param_exceptions)
- with contextlib.suppress(AttributeError):
- params.update(dag.params)
- if task.params:
- params.update(task.params)
- if conf.getboolean('core', 'dag_run_conf_overrides_params'):
- self.overwrite_params_with_dag_run_conf(params=params, dag_run=dag_run)
- validated_params = params.validate()
+ validated_params = process_params(dag, task, dag_run, suppress_exception=ignore_param_exceptions)
logical_date = timezone.coerce_datetime(self.execution_date)
- ds = logical_date.strftime('%Y-%m-%d')
- ds_nodash = ds.replace('-', '')
+ ds = logical_date.strftime("%Y-%m-%d")
+ ds_nodash = ds.replace("-", "")
ts = logical_date.isoformat()
- ts_nodash = logical_date.strftime('%Y%m%dT%H%M%S')
- ts_nodash_with_tz = ts.replace('-', '').replace(':', '')
+ ts_nodash = logical_date.strftime("%Y%m%dT%H%M%S")
+ ts_nodash_with_tz = ts.replace("-", "").replace(":", "")
@cache # Prevent multiple database access.
- def _get_previous_dagrun_success() -> Optional["DagRun"]:
+ def _get_previous_dagrun_success() -> DagRun | None:
return self.get_previous_dagrun(state=DagRunState.SUCCESS, session=session)
- def _get_previous_dagrun_data_interval_success() -> Optional["DataInterval"]:
+ def _get_previous_dagrun_data_interval_success() -> DataInterval | None:
dagrun = _get_previous_dagrun_success()
if dagrun is None:
return None
return dag.get_run_data_interval(dagrun)
- def get_prev_data_interval_start_success() -> Optional[pendulum.DateTime]:
+ def get_prev_data_interval_start_success() -> pendulum.DateTime | None:
data_interval = _get_previous_dagrun_data_interval_success()
if data_interval is None:
return None
return data_interval.start
- def get_prev_data_interval_end_success() -> Optional[pendulum.DateTime]:
+ def get_prev_data_interval_end_success() -> pendulum.DateTime | None:
data_interval = _get_previous_dagrun_data_interval_success()
if data_interval is None:
return None
return data_interval.end
- def get_prev_start_date_success() -> Optional[pendulum.DateTime]:
+ def get_prev_start_date_success() -> pendulum.DateTime | None:
dagrun = _get_previous_dagrun_success()
if dagrun is None:
return None
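Reviewer sketch (not part of the patch): the date strings placed in the template context (``ds``, ``ds_nodash``, ``ts``, ``ts_nodash``, ``ts_nodash_with_tz``) are all derived from ``logical_date`` with ``strftime``, ``isoformat`` and ``replace``. The helper below is hypothetical and uses a standard-library ``datetime`` in place of the pendulum value used in the source.

```python
from __future__ import annotations

from datetime import datetime, timezone


def template_date_strings(logical_date: datetime) -> dict[str, str]:
    """Derive the date strings the same way the diffed code does."""
    ds = logical_date.strftime("%Y-%m-%d")
    ts = logical_date.isoformat()
    return {
        "ds": ds,
        "ds_nodash": ds.replace("-", ""),
        "ts": ts,
        "ts_nodash": logical_date.strftime("%Y%m%dT%H%M%S"),
        # Built from ts, so the UTC offset is kept but loses its colon.
        "ts_nodash_with_tz": ts.replace("-", "").replace(":", ""),
    }


strings = template_date_strings(datetime(2022, 1, 2, 3, 4, 5, tzinfo=timezone.utc))
```

Because ``ts_nodash_with_tz`` strips every ``-``, a negative UTC offset would lose its sign in this scheme; the example uses UTC, where the offset is ``+0000``.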
@@ -1965,20 +1909,20 @@ def get_prev_start_date_success() -> Optional[pendulum.DateTime]:
@cache
def get_yesterday_ds() -> str:
- return (logical_date - timedelta(1)).strftime('%Y-%m-%d')
+ return (logical_date - timedelta(1)).strftime("%Y-%m-%d")
def get_yesterday_ds_nodash() -> str:
- return get_yesterday_ds().replace('-', '')
+ return get_yesterday_ds().replace("-", "")
@cache
def get_tomorrow_ds() -> str:
- return (logical_date + timedelta(1)).strftime('%Y-%m-%d')
+ return (logical_date + timedelta(1)).strftime("%Y-%m-%d")
def get_tomorrow_ds_nodash() -> str:
- return get_tomorrow_ds().replace('-', '')
+ return get_tomorrow_ds().replace("-", "")
@cache
- def get_next_execution_date() -> Optional[pendulum.DateTime]:
+ def get_next_execution_date() -> pendulum.DateTime | None:
# For manually triggered dagruns that aren't run on a schedule,
# the "next" execution date doesn't make sense, and should be set
# to execution date for consistency with how execution_date is set
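Reviewer sketch (not part of the patch): the ``@cache`` decorators on the nested helpers mean the ``_nodash`` variants reuse an already computed value instead of recomputing it. The factory below is hypothetical; it uses ``functools.cache`` (Python 3.9+) as a stand-in for the ``cache`` imported in the source, and a counter to make the memoization visible.

```python
from datetime import datetime, timedelta
from functools import cache


def make_yesterday_helpers(logical_date: datetime):
    """Build closures over logical_date, mirroring the nested helpers."""
    calls = {"yesterday": 0}

    @cache
    def get_yesterday_ds() -> str:
        calls["yesterday"] += 1  # only incremented on a cache miss
        return (logical_date - timedelta(1)).strftime("%Y-%m-%d")

    def get_yesterday_ds_nodash() -> str:
        # Reuses the cached value rather than redoing the date arithmetic.
        return get_yesterday_ds().replace("-", "")

    return get_yesterday_ds, get_yesterday_ds_nodash, calls


yesterday_ds, yesterday_ds_nodash, calls = make_yesterday_helpers(datetime(2022, 3, 1))
first = yesterday_ds()
second = yesterday_ds_nodash()
```

In the source the same decorator is also applied to helpers that query the database, where avoiding a repeated call matters more than it does for this date arithmetic.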
@@ -1992,17 +1936,17 @@ def get_next_execution_date() -> Optional[pendulum.DateTime]:
return None
return timezone.coerce_datetime(next_info.logical_date)
- def get_next_ds() -> Optional[str]:
+ def get_next_ds() -> str | None:
execution_date = get_next_execution_date()
if execution_date is None:
return None
- return execution_date.strftime('%Y-%m-%d')
+ return execution_date.strftime("%Y-%m-%d")
- def get_next_ds_nodash() -> Optional[str]:
+ def get_next_ds_nodash() -> str | None:
ds = get_next_ds()
if ds is None:
return ds
- return ds.replace('-', '')
+ return ds.replace("-", "")
@cache
def get_prev_execution_date():
@@ -2013,70 +1957,91 @@ def get_prev_execution_date():
if dag_run.external_trigger:
return logical_date
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
return dag.previous_schedule(logical_date)
@cache
- def get_prev_ds() -> Optional[str]:
+ def get_prev_ds() -> str | None:
execution_date = get_prev_execution_date()
if execution_date is None:
return None
- return execution_date.strftime(r'%Y-%m-%d')
+ return execution_date.strftime(r"%Y-%m-%d")
- def get_prev_ds_nodash() -> Optional[str]:
+ def get_prev_ds_nodash() -> str | None:
prev_ds = get_prev_ds()
if prev_ds is None:
return None
- return prev_ds.replace('-', '')
+ return prev_ds.replace("-", "")
+
+ def get_triggering_events() -> dict[str, list[DatasetEvent]]:
+ nonlocal dag_run
+ # The dag_run may not be attached to the session anymore (code base is over-zealous with use of
+ # `session.expunge_all()`) so re-attach it if we get called
+ if dag_run not in session:
+ dag_run = session.merge(dag_run, load=False)
+
+ dataset_events = dag_run.consumed_dataset_events
+ triggering_events: dict[str, list[DatasetEvent]] = defaultdict(list)
+ for event in dataset_events:
+ triggering_events[event.dataset.uri].append(event)
+
+ return triggering_events
+
+ try:
+ expanded_ti_count: int | None = task.get_mapped_ti_count(self.run_id, session=session)
+ except NotMapped:
+ expanded_ti_count = None
# NOTE: If you add anything to this dict, make sure to also update the
# definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in
# airflow/utils/context.py!
context = {
- 'conf': conf,
- 'dag': dag,
- 'dag_run': dag_run,
- 'data_interval_end': timezone.coerce_datetime(data_interval.end),
- 'data_interval_start': timezone.coerce_datetime(data_interval.start),
- 'ds': ds,
- 'ds_nodash': ds_nodash,
- 'execution_date': logical_date,
- 'inlets': task.inlets,
- 'logical_date': logical_date,
- 'macros': macros,
- 'next_ds': get_next_ds(),
- 'next_ds_nodash': get_next_ds_nodash(),
- 'next_execution_date': get_next_execution_date(),
- 'outlets': task.outlets,
- 'params': validated_params,
- 'prev_data_interval_start_success': get_prev_data_interval_start_success(),
- 'prev_data_interval_end_success': get_prev_data_interval_end_success(),
- 'prev_ds': get_prev_ds(),
- 'prev_ds_nodash': get_prev_ds_nodash(),
- 'prev_execution_date': get_prev_execution_date(),
- 'prev_execution_date_success': self.get_previous_execution_date(
+ "conf": conf,
+ "dag": dag,
+ "dag_run": dag_run,
+ "data_interval_end": timezone.coerce_datetime(data_interval.end),
+ "data_interval_start": timezone.coerce_datetime(data_interval.start),
+ "ds": ds,
+ "ds_nodash": ds_nodash,
+ "execution_date": logical_date,
+ "expanded_ti_count": expanded_ti_count,
+ "inlets": task.inlets,
+ "logical_date": logical_date,
+ "macros": macros,
+ "next_ds": get_next_ds(),
+ "next_ds_nodash": get_next_ds_nodash(),
+ "next_execution_date": get_next_execution_date(),
+ "outlets": task.outlets,
+ "params": validated_params,
+ "prev_data_interval_start_success": get_prev_data_interval_start_success(),
+ "prev_data_interval_end_success": get_prev_data_interval_end_success(),
+ "prev_ds": get_prev_ds(),
+ "prev_ds_nodash": get_prev_ds_nodash(),
+ "prev_execution_date": get_prev_execution_date(),
+ "prev_execution_date_success": self.get_previous_execution_date(
state=DagRunState.SUCCESS,
session=session,
),
- 'prev_start_date_success': get_prev_start_date_success(),
- 'run_id': self.run_id,
- 'task': task,
- 'task_instance': self,
- 'task_instance_key_str': f"{task.dag_id}__{task.task_id}__{ds_nodash}",
- 'test_mode': self.test_mode,
- 'ti': self,
- 'tomorrow_ds': get_tomorrow_ds(),
- 'tomorrow_ds_nodash': get_tomorrow_ds_nodash(),
- 'ts': ts,
- 'ts_nodash': ts_nodash,
- 'ts_nodash_with_tz': ts_nodash_with_tz,
- 'var': {
- 'json': VariableAccessor(deserialize_json=True),
- 'value': VariableAccessor(deserialize_json=False),
+ "prev_start_date_success": get_prev_start_date_success(),
+ "run_id": self.run_id,
+ "task": task,
+ "task_instance": self,
+ "task_instance_key_str": f"{task.dag_id}__{task.task_id}__{ds_nodash}",
+ "test_mode": self.test_mode,
+ "ti": self,
+ "tomorrow_ds": get_tomorrow_ds(),
+ "tomorrow_ds_nodash": get_tomorrow_ds_nodash(),
+ "triggering_dataset_events": lazy_object_proxy.Proxy(get_triggering_events),
+ "ts": ts,
+ "ts_nodash": ts_nodash,
+ "ts_nodash_with_tz": ts_nodash_with_tz,
+ "var": {
+ "json": VariableAccessor(deserialize_json=True),
+ "value": VariableAccessor(deserialize_json=False),
},
- 'conn': ConnectionAccessor(),
- 'yesterday_ds': get_yesterday_ds(),
- 'yesterday_ds_nodash': get_yesterday_ds_nodash(),
+ "conn": ConnectionAccessor(),
+ "yesterday_ds": get_yesterday_ds(),
+ "yesterday_ds_nodash": get_yesterday_ds_nodash(),
}
# Mypy doesn't like turning existing dicts in to a TypeDict -- and we "lie" in the type stub to say it
# is one, but in practice it isn't. See https://github.com/python/mypy/issues/8890
@@ -2092,7 +2057,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
rendered_task_instance_fields = RenderedTaskInstanceFields.get_templated_fields(self, session=session)
if rendered_task_instance_fields:
- self.task = self.task.unmap()
+ self.task = self.task.unmap(None)
for field_name, rendered_value in rendered_task_instance_fields.items():
setattr(self.task, field_name, rendered_value)
return
@@ -2114,7 +2079,7 @@ def get_rendered_template_fields(self, session: Session = NEW_SESSION) -> None:
) from e
@provide_session
- def get_rendered_k8s_spec(self, session=NEW_SESSION):
+ def get_rendered_k8s_spec(self, session: Session = NEW_SESSION):
"""Fetch rendered template fields from DB"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -2132,7 +2097,7 @@ def overwrite_params_with_dag_run_conf(self, params, dag_run):
self.log.debug("Updating task params (%s) with DagRun.conf (%s)", params, dag_run.conf)
params.update(dag_run.conf)
- def render_templates(self, context: Optional[Context] = None) -> "Operator":
+ def render_templates(self, context: Context | None = None) -> Operator:
"""Render templates in the operator fields.
If the task was originally mapped, this may replace ``self.task`` with
@@ -2141,13 +2106,17 @@ def render_templates(self, context: Optional[Context] = None) -> "Operator":
"""
if not context:
context = self.get_template_context()
- rendered_task = self.task.render_template_fields(context)
- if rendered_task is None: # Compatibility -- custom renderer, assume unmapped.
- return self.task
- original_task, self.task = self.task, rendered_task
+ original_task = self.task
+
+ # If self.task is mapped, this call replaces self.task to point to the
+ # unmapped BaseOperator created by this function! This is because the
+ # MappedOperator is useless for template rendering, and we need to be
+ # able to access the unmapped task instead.
+ original_task.render_template_fields(context)
+
return original_task
- def render_k8s_pod_yaml(self) -> Optional[dict]:
+ def render_k8s_pod_yaml(self) -> dict | None:
"""Render k8s pod yaml"""
from kubernetes.client.api_client import ApiClient
@@ -2176,40 +2145,39 @@ def render_k8s_pod_yaml(self) -> Optional[dict]:
return sanitized_pod
def get_email_subject_content(
- self, exception: BaseException, task: Optional["BaseOperator"] = None
- ) -> Tuple[str, str, str]:
+ self, exception: BaseException, task: BaseOperator | None = None
+ ) -> tuple[str, str, str]:
"""Get the email subject content for exceptions."""
# For a ti from DB (without ti.task), return the default value
- # Reuse it for smart sensor to send default email alert
if task is None:
- task = getattr(self, 'task')
+ task = getattr(self, "task")
use_default = task is None
- exception_html = str(exception).replace('\n', '<br>')
+ exception_html = str(exception).replace("\n", "<br>")
- default_subject = 'Airflow alert: {{ti}}'
+ default_subject = "Airflow alert: {{ti}}"
# For reporting purposes, we report based on 1-indexed,
# not 0-indexed lists (i.e. Try 1 instead of
# Try 0 for the first attempt).
default_html_content = (
- 'Try {{try_number}} out of {{max_tries + 1}}<br>'
- 'Exception:<br>{{exception_html}}<br>'
+ "Try {{try_number}} out of {{max_tries + 1}}<br>"
+ "Exception:<br>{{exception_html}}<br>"
'Log: <a href="{{ti.log_url}}">Link</a><br>'
- 'Host: {{ti.hostname}}<br>'
+ "Host: {{ti.hostname}}<br>"
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
default_html_content_err = (
- 'Try {{try_number}} out of {{max_tries + 1}}<br>'
- 'Exception:<br>Failed attempt to attach error logs<br>'
+ "Try {{try_number}} out of {{max_tries + 1}}<br>"
+ "Exception:<br>Failed attempt to attach error logs<br>"
'Log: <a href="{{ti.log_url}}">Link</a><br>'
- 'Host: {{ti.hostname}}<br>'
+ "Host: {{ti.hostname}}<br>"
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
# This function is called after changing the state from State.RUNNING,
# so we need to subtract 1 from self.try_number here.
current_try_number = self.try_number - 1
- additional_context: Dict[str, Any] = {
+ additional_context: dict[str, Any] = {
"exception": exception,
"exception_html": exception_html,
"try_number": current_try_number,
@@ -2238,19 +2206,24 @@ def get_email_subject_content(
context_merge(jinja_context, additional_context)
def render(key: str, content: str) -> str:
- if conf.has_option('email', key):
- path = conf.get_mandatory_value('email', key)
- with open(path) as f:
- content = f.read()
+ if conf.has_option("email", key):
+ path = conf.get_mandatory_value("email", key)
+ try:
+ with open(path) as f:
+ content = f.read()
+ except FileNotFoundError:
+ self.log.warning(f"Could not find email template file '{path!r}'. Using defaults...")
+ except OSError:
+ self.log.exception(f"Error while using email template '{path!r}'. Using defaults...")
return render_template_to_string(jinja_env.from_string(content), jinja_context)
- subject = render('subject_template', default_subject)
- html_content = render('html_content_template', default_html_content)
- html_content_err = render('html_content_template', default_html_content_err)
+ subject = render("subject_template", default_subject)
+ html_content = render("html_content_template", default_html_content)
+ html_content_err = render("html_content_template", default_html_content_err)
return subject, html_content, html_content_err
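
The fallback that this hunk adds to `render` can be looked at in isolation. The sketch below is not the Airflow code: the helper name `load_template_or_default` and the module-level logger are hypothetical, and only the `FileNotFoundError` / `OSError` handling mirrors the change. It assumes the intent is "use the configured file when readable, otherwise keep the built-in default".

```python
import logging

log = logging.getLogger(__name__)


def load_template_or_default(path: str, default: str) -> str:
    """Return the contents of ``path``, or ``default`` if it cannot be read.

    Hypothetical helper for illustration; in the diff this logic lives
    inline inside ``render``.
    """
    try:
        with open(path) as f:
            return f.read()
    except FileNotFoundError:
        # A missing file is an expected misconfiguration: warn and fall back.
        log.warning("Could not find email template file %r. Using defaults...", path)
    except OSError:
        # Any other I/O failure (permissions, directory, ...) is logged with a traceback.
        log.exception("Error while using email template %r. Using defaults...", path)
    return default
```

`FileNotFoundError` is a subclass of `OSError`, so the more specific handler has to come first, as it does in the hunk.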
- def email_alert(self, exception, task: "BaseOperator"):
+ def email_alert(self, exception, task: BaseOperator) -> None:
"""Send alert email with exception information."""
subject, html_content, html_content_err = self.get_email_subject_content(exception, task=task)
assert task.email
@@ -2267,18 +2240,18 @@ def set_duration(self) -> None:
self.duration = None
self.log.debug("Task Duration set to %s", self.duration)
- def _record_task_map_for_downstreams(self, task: "Operator", value: Any, *, session: Session) -> None:
+ def _record_task_map_for_downstreams(self, task: Operator, value: Any, *, session: Session) -> None:
+ if next(task.iter_mapped_dependants(), None) is None: # No mapped dependants, no need to validate.
+ return
# TODO: We don't push TaskMap for mapped task instances because it's not
# currently possible for a downstream to depend on one individual mapped
- # task instance, only a task as a whole. This will change in AIP-42
- # Phase 2, and we'll need to further analyze the mapped task case.
- if next(task.iter_mapped_dependants(), None) is None:
+ # task instance. This will change when we implement task mapping inside
+ # a mapped task group, and we'll need to further analyze the case.
+ if isinstance(task, MappedOperator):
return
if value is None:
raise XComForMappingNotPushed()
- if task.is_mapped:
- return
- if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
+ if not _is_mappable_value(value):
raise UnmappableXComTypePushed(value)
task_map = TaskMap.from_task_instance_xcom(self, value)
max_map_length = conf.getint("core", "max_map_length", fallback=1024)
@@ -2291,7 +2264,7 @@ def xcom_push(
self,
key: str,
value: Any,
- execution_date: Optional[datetime] = None,
+ execution_date: datetime | None = None,
session: Session = NEW_SESSION,
) -> None:
"""
@@ -2307,12 +2280,12 @@ def xcom_push(
self_execution_date = self.get_dagrun(session).execution_date
if execution_date < self_execution_date:
raise ValueError(
- f'execution_date can not be in the past (current execution_date is '
- f'{self_execution_date}; received {execution_date})'
+ f"execution_date can not be in the past (current execution_date is "
+ f"{self_execution_date}; received {execution_date})"
)
elif execution_date is not None:
message = "Passing 'execution_date' to 'TaskInstance.xcom_push()' is deprecated."
- warnings.warn(message, DeprecationWarning, stacklevel=3)
+ warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
XCom.set(
key=key,
@@ -2327,13 +2300,13 @@ def xcom_push(
@provide_session
def xcom_pull(
self,
- task_ids: Optional[Union[str, Iterable[str]]] = None,
- dag_id: Optional[str] = None,
+ task_ids: str | Iterable[str] | None = None,
+ dag_id: str | None = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
*,
- map_indexes: Optional[Union[int, Iterable[int]]] = None,
+ map_indexes: int | Iterable[int] | None = None,
default: Any = None,
) -> Any:
"""Pull XComs that optionally meet certain criteria.
@@ -2392,38 +2365,37 @@ def xcom_pull(
return default
if map_indexes is not None or first.map_index < 0:
return XCom.deserialize_value(first)
-
- return _LazyXComAccess.build_from_single_xcom(first, query)
+ query = query.order_by(None).order_by(XCom.map_index.asc())
+ return LazyXComAccess.build_from_xcom_query(query)
# At this point either task_ids or map_indexes is explicitly multi-value.
-
- results = (
- (r.task_id, r.map_index, XCom.deserialize_value(r))
- for r in query.with_entities(XCom.task_id, XCom.map_index, XCom.value)
- )
-
- if task_ids is None:
- task_id_pos: Dict[str, int] = defaultdict(int)
- elif isinstance(task_ids, str):
- task_id_pos = {task_ids: 0}
+ # Order return values to match task_ids and map_indexes ordering.
+ query = query.order_by(None)
+ if task_ids is None or isinstance(task_ids, str):
+ query = query.order_by(XCom.task_id)
else:
- task_id_pos = {task_id: i for i, task_id in enumerate(task_ids)}
- if map_indexes is None:
- map_index_pos: Dict[int, int] = defaultdict(int)
- elif isinstance(map_indexes, int):
- map_index_pos = {map_indexes: 0}
+ task_id_whens = {tid: i for i, tid in enumerate(task_ids)}
+ if task_id_whens:
+ query = query.order_by(case(task_id_whens, value=XCom.task_id))
+ else:
+ query = query.order_by(XCom.task_id)
+ if map_indexes is None or isinstance(map_indexes, int):
+ query = query.order_by(XCom.map_index)
+ elif isinstance(map_indexes, range):
+ order = XCom.map_index
+ if map_indexes.step < 0:
+ order = order.desc()
+ query = query.order_by(order)
else:
- map_index_pos = {map_index: i for i, map_index in enumerate(map_indexes)}
-
- def _arg_pos(item: Tuple[str, int, Any]) -> Tuple[int, int]:
- task_id, map_index, _ = item
- return task_id_pos[task_id], map_index_pos[map_index]
-
- results_sorted_by_arg_pos = sorted(results, key=_arg_pos)
- return [value for _, _, value in results_sorted_by_arg_pos]
+ map_index_whens = {map_index: i for i, map_index in enumerate(map_indexes)}
+ if map_index_whens:
+ query = query.order_by(case(map_index_whens, value=XCom.map_index))
+ else:
+ query = query.order_by(XCom.map_index)
+ return LazyXComAccess.build_from_xcom_query(query)
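
The new code asks the database to order rows by the position of each `task_id` (and `map_index`) in the caller's arguments, instead of sorting in Python as the removed lines did. The ordering it aims for can be sketched in plain Python. This is an illustration only: `order_like_request` and the row tuples are hypothetical, and the sketch covers just the case where `task_ids` is a list and `map_indexes` is left unset.

```python
from __future__ import annotations


def order_like_request(
    rows: list[tuple[str, int, object]], task_ids: list[str]
) -> list[tuple[str, int, object]]:
    """Order (task_id, map_index, value) rows to follow ``task_ids``.

    Hypothetical stand-in for the SQL ``ORDER BY CASE ...`` in the diff:
    each task_id is mapped to its position in the request, and map_index
    breaks ties in ascending order.
    """
    position = {task_id: i for i, task_id in enumerate(task_ids)}
    return sorted(rows, key=lambda row: (position[row[0]], row[1]))


rows = [("b", 0, "B0"), ("a", 1, "A1"), ("a", 0, "A0")]
ordered = order_like_request(rows, ["b", "a"])
```

Here `ordered` lists the row for `"b"` first, then the two `"a"` rows by ascending `map_index`, matching the order in which the task ids were requested.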
@provide_session
- def get_num_running_task_instances(self, session):
+ def get_num_running_task_instances(self, session: Session) -> int:
"""Return Number of running TIs from the DB"""
# .count() is inefficient
return (
@@ -2436,13 +2408,13 @@ def get_num_running_task_instances(self, session):
.scalar()
)
- def init_run_context(self, raw=False):
+ def init_run_context(self, raw: bool = False) -> None:
"""Sets the log context."""
self.raw = raw
self._set_context(self)
@staticmethod
- def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Optional[BooleanClauseList]:
+ def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClauseList | None:
"""Returns SQLAlchemy filter to query selected task instances"""
# DictKeys type, (what we often pass here from the scheduler) is not directly indexable :(
# Or it might be a generator, but we need to be able to iterate over it more than once
@@ -2457,37 +2429,81 @@ def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Opt
run_id = first.run_id
map_index = first.map_index
first_task_id = first.task_id
+
+ # pre-compute the set of dag_id, run_id, map_indices and task_ids
+ dag_ids, run_ids, map_indices, task_ids = set(), set(), set(), set()
+ for t in tis:
+ dag_ids.add(t.dag_id)
+ run_ids.add(t.run_id)
+ map_indices.add(t.map_index)
+ task_ids.add(t.task_id)
+
# Common path optimisations: when all TIs are for the same dag_id and run_id, or same dag_id
# and task_id -- this can be over 150x faster for huge numbers of TIs (20k+)
- if all(t.dag_id == dag_id and t.run_id == run_id and t.map_index == map_index for t in tis):
+ if dag_ids == {dag_id} and run_ids == {run_id} and map_indices == {map_index}:
return and_(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
TaskInstance.map_index == map_index,
- TaskInstance.task_id.in_(t.task_id for t in tis),
+ TaskInstance.task_id.in_(task_ids),
)
- if all(t.dag_id == dag_id and t.task_id == first_task_id and t.map_index == map_index for t in tis):
+ if dag_ids == {dag_id} and task_ids == {first_task_id} and map_indices == {map_index}:
return and_(
TaskInstance.dag_id == dag_id,
- TaskInstance.run_id.in_(t.run_id for t in tis),
+ TaskInstance.run_id.in_(run_ids),
TaskInstance.map_index == map_index,
TaskInstance.task_id == first_task_id,
)
- if all(t.dag_id == dag_id and t.run_id == run_id and t.task_id == first_task_id for t in tis):
+ if dag_ids == {dag_id} and run_ids == {run_id} and task_ids == {first_task_id}:
return and_(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == run_id,
- TaskInstance.map_index.in_(t.map_index for t in tis),
+ TaskInstance.map_index.in_(map_indices),
TaskInstance.task_id == first_task_id,
)
- return tuple_in_condition(
- (TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index),
- (ti.key.primary for ti in tis),
- )
+ filter_condition = []
+ # create 2 nested groups, both primarily grouped by dag_id and run_id,
+ # and in the nested group 1 grouped by task_id the other by map_index.
+ task_id_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
+ map_index_groups: dict[tuple, dict[Any, list[Any]]] = defaultdict(lambda: defaultdict(list))
+ for t in tis:
+ task_id_groups[(t.dag_id, t.run_id)][t.task_id].append(t.map_index)
+ map_index_groups[(t.dag_id, t.run_id)][t.map_index].append(t.task_id)
+
+ # this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
+ # if its not, this is still a significant optimization over querying for every single tuple key
+ for cur_dag_id in dag_ids:
+ for cur_run_id in run_ids:
+ # we compare the group size between task_id and map_index and use the smaller group
+ dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
+ dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
+
+ if len(dag_task_id_groups) <= len(dag_map_index_groups):
+ for cur_task_id, cur_map_indices in dag_task_id_groups.items():
+ filter_condition.append(
+ and_(
+ TaskInstance.dag_id == cur_dag_id,
+ TaskInstance.run_id == cur_run_id,
+ TaskInstance.task_id == cur_task_id,
+ TaskInstance.map_index.in_(cur_map_indices),
+ )
+ )
+ else:
+ for cur_map_index, cur_task_ids in dag_map_index_groups.items():
+ filter_condition.append(
+ and_(
+ TaskInstance.dag_id == cur_dag_id,
+ TaskInstance.run_id == cur_run_id,
+ TaskInstance.task_id.in_(cur_task_ids),
+ TaskInstance.map_index == cur_map_index,
+ )
+ )
+
+ return or_(*filter_condition)
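
The grouping strategy in the fallback path can be tried without SQLAlchemy. The sketch below is a simplified model, not the Airflow implementation: `Key`, `group_conditions`, and the tuple-shaped "conditions" are hypothetical, and it loops over the `(dag_id, run_id)` pairs that actually occur rather than over the cross product of `dag_ids` and `run_ids` as the diff does.

```python
from __future__ import annotations

from collections import defaultdict
from typing import NamedTuple


class Key(NamedTuple):
    dag_id: str
    run_id: str
    task_id: str
    map_index: int


def group_conditions(keys: list[Key]) -> list[tuple]:
    """Collapse keys into one condition per group, using the smaller grouping.

    Each condition is (dag_id, run_id, eq_column, eq_value, in_column, in_values),
    standing in for ``and_(col == value, other_col.in_(values))``.
    """
    by_task: dict[tuple, dict[str, list[int]]] = defaultdict(lambda: defaultdict(list))
    by_index: dict[tuple, dict[int, list[str]]] = defaultdict(lambda: defaultdict(list))
    for k in keys:
        by_task[(k.dag_id, k.run_id)][k.task_id].append(k.map_index)
        by_index[(k.dag_id, k.run_id)][k.map_index].append(k.task_id)

    conditions: list[tuple] = []
    for pair, task_groups in by_task.items():
        index_groups = by_index[pair]
        # Fewer groups means fewer OR-ed clauses in the final filter.
        if len(task_groups) <= len(index_groups):
            for task_id, map_indices in task_groups.items():
                conditions.append((*pair, "task_id", task_id, "map_index", map_indices))
        else:
            for map_index, task_ids in index_groups.items():
                conditions.append((*pair, "map_index", map_index, "task_id", task_ids))
    return conditions


keys = [
    Key("d", "r1", "t1", 0),
    Key("d", "r1", "t1", 1),
    Key("d", "r1", "t1", 2),
    Key("d", "r1", "t2", 0),
]
conditions = group_conditions(keys)
```

With these four keys there are two `task_id` groups and three `map_index` groups, so grouping by `task_id` wins and the four keys collapse into two conditions.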
@classmethod
- def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) -> ColumnOperators:
+ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> ColumnOperators:
"""
Build an SQLAlchemy filter for a list where each element can contain
whether a task_id, or a tuple of (task_id,map_index)
@@ -2499,7 +2515,7 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) ->
task_id_only = [v for v in vals if isinstance(v, str)]
with_map_index = [v for v in vals if not isinstance(v, str)]
- filters: List[ColumnOperators] = []
+ filters: list[ColumnOperators] = []
if task_id_only:
filters.append(cls.task_id.in_(task_id_only))
if with_map_index:
@@ -2511,6 +2527,164 @@ def ti_selector_condition(cls, vals: Collection[Union[str, Tuple[str, int]]]) ->
return filters[0]
return or_(*filters)
+ @Sentry.enrich_errors
+ @provide_session
+ def schedule_downstream_tasks(self, session=None):
+ """
+ The mini-scheduler for scheduling downstream tasks of this task instance
+ :meta: private
+ """
+ from sqlalchemy.exc import OperationalError
+
+ from airflow.models import DagRun
+
+ try:
+ # Re-select the row with a lock
+ dag_run = with_row_locks(
+ session.query(DagRun).filter_by(
+ dag_id=self.dag_id,
+ run_id=self.run_id,
+ ),
+ session=session,
+ ).one()
+
+ task = self.task
+ if TYPE_CHECKING:
+ assert task.dag
+
+ # Get a partial DAG with just the specific tasks we want to examine.
+ # In order for dep checks to work correctly, we include ourself (so
+ # TriggerRuleDep can check the state of the task we just executed).
+ partial_dag = task.dag.partial_subset(
+ task.downstream_task_ids,
+ include_downstream=True,
+ include_upstream=False,
+ include_direct_upstream=True,
+ )
+
+ dag_run.dag = partial_dag
+ info = dag_run.task_instance_scheduling_decisions(session)
+
+ skippable_task_ids = {
+ task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+ }
+
+ schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+ for schedulable_ti in schedulable_tis:
+ if not hasattr(schedulable_ti, "task"):
+ schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+ num = dag_run.schedule_tis(schedulable_tis, session=session)
+ self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+ session.flush()
+
+ except OperationalError as e:
+ # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+ self.log.info(
+ "Skipping mini scheduling run due to exception: %s",
+ e.statement,
+ exc_info=True,
+ )
+ session.rollback()
+
+ def get_relevant_upstream_map_indexes(
+ self,
+ upstream: Operator,
+ ti_count: int | None,
+ *,
+ session: Session,
+ ) -> int | range | None:
+ """Infer the map indexes of an upstream "relevant" to this ti.
+
+ The bulk of the logic mainly exists to solve the problem described by
+ the following example, where 'val' must resolve to different values,
+ depending on where the reference is being used::
+
+ @task
+ def this_task(v): # This is self.task.
+ return v * 2
+
+ @task_group
+ def tg1(inp):
+ val = upstream(inp) # This is the upstream task.
+ this_task(val) # When inp is 1, val here should resolve to 2.
+ return val
+
+ # This val is the same object returned by tg1.
+ val = tg1.expand(inp=[1, 2, 3])
+
+ @task_group
+ def tg2(inp):
+ another_task(inp, val) # val here should resolve to [2, 4, 6].
+
+ tg2.expand(inp=["a", "b"])
+
+ The surrounding mapped task groups of ``upstream`` and ``self.task`` are
+ inspected to find a common "ancestor". If such an ancestor is found,
+ we need to return specific map indexes to pull a partial value from
+ upstream XCom.
+
+ :param upstream: The referenced upstream task.
+ :param ti_count: The total count of task instance this task was expanded
+ by the scheduler, i.e. ``expanded_ti_count`` in the template context.
+ :return: Specific map index or map indexes to pull, or ``None`` if we
+ want the "whole" return value (i.e. no mapped task groups involved).
+ """
+ # Find the innermost common mapped task group between the current task
+ # and the referenced upstream task. If the two do not share a common
+ # mapped task group, they are in different task mapping contexts
+ # (like another_task above), and we should use the "whole" value.
+ common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
+ if common_ancestor is None:
+ return None
+
+ # This value should never be None since we already know the current task
+ # is in a mapped task group, and should have been expanded. The check
+ # exists mainly to satisfy Mypy.
+ if ti_count is None:
+ return None
+
+ # At this point we know the two tasks share a mapped task group, and we
+ # should use a "partial" value. Let's break down the mapped ti count
+ # between the ancestor and further expansion happened inside it.
+ ancestor_ti_count = common_ancestor.get_mapped_ti_count(self.run_id, session=session)
+ ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
+
+ # If the task is NOT further expanded inside the common ancestor, we
+ # only want to reference one single ti. We must walk the actual DAG,
+ # and "ti_count == ancestor_ti_count" does not work, since the further
+ # expansion may be of length 1.
+ if not _is_further_mapped_inside(upstream, common_ancestor):
+ return ancestor_map_index
+
+ # Otherwise we need a partial aggregation for values from selected task
+ # instances in the ancestor's expansion context.
+ further_count = ti_count // ancestor_ti_count
+ map_index_start = ancestor_map_index * further_count
+ return range(map_index_start, map_index_start + further_count)
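
The index arithmetic at the end of `get_relevant_upstream_map_indexes` is easier to follow with concrete numbers. The function below is a standalone sketch: its name and the boolean parameter `upstream_further_mapped` are hypothetical stand-ins for the DAG lookups the method performs (`get_mapped_ti_count` and `_is_further_mapped_inside`). Only the integer arithmetic is taken from the diff.

```python
from __future__ import annotations


def relevant_map_indexes(
    map_index: int,
    ti_count: int,
    ancestor_ti_count: int,
    upstream_further_mapped: bool,
) -> int | range:
    """Reproduce the map-index arithmetic from the method above.

    ``ti_count`` is the total number of expanded task instances of the
    current task; ``ancestor_ti_count`` is how many times the common
    mapped task group was expanded.
    """
    # Which expansion of the common ancestor does this task instance belong to?
    ancestor_map_index = map_index * ancestor_ti_count // ti_count
    if not upstream_further_mapped:
        return ancestor_map_index
    # Upstream is expanded again inside the ancestor: pull a contiguous slice.
    further_count = ti_count // ancestor_ti_count
    start = ancestor_map_index * further_count
    return range(start, start + further_count)


# Group expanded 3 times, task not expanded further: index 1 maps to ancestor index 1.
single = relevant_map_indexes(1, 3, 3, upstream_further_mapped=False)
# Group expanded 3 times, 2 instances per group expansion: index 4 sits in ancestor slot 2.
partial = relevant_map_indexes(4, 6, 3, upstream_further_mapped=True)
```

In the second call `ancestor_map_index` is `4 * 3 // 6 = 2` and `further_count` is `6 // 3 = 2`, so the relevant upstream instances are those with map indexes 4 and 5.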
+
+
+def _find_common_ancestor_mapped_group(node1: Operator, node2: Operator) -> MappedTaskGroup | None:
+ """Given two operators, find their innermost common mapped task group."""
+ if node1.dag is None or node2.dag is None or node1.dag_id != node2.dag_id:
+ return None
+ parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
+ common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
+ return next(common_groups, None)
+
+
+def _is_further_mapped_inside(operator: Operator, container: TaskGroup) -> bool:
+ """Whether given operator is *further* mapped inside a task group."""
+ if isinstance(operator, MappedOperator):
+ return True
+ task_group = operator.task_group
+ while task_group is not None and task_group.group_id != container.group_id:
+ if isinstance(task_group, MappedTaskGroup):
+ return True
+ task_group = task_group.parent_group
+ return False
+
# State of the task instance.
# Stores string version of the task state.
@@ -2529,8 +2703,8 @@ def __init__(
dag_id: str,
task_id: str,
run_id: str,
- start_date: Optional[datetime],
- end_date: Optional[datetime],
+ start_date: datetime | None,
+ end_date: datetime | None,
try_number: int,
map_index: int,
state: str,
@@ -2538,8 +2712,8 @@ def __init__(
pool: str,
queue: str,
key: TaskInstanceKey,
- run_as_user: Optional[str] = None,
- priority_weight: Optional[int] = None,
+ run_as_user: str | None = None,
+ priority_weight: int | None = None,
):
self.dag_id = dag_id
self.task_id = task_id
@@ -2561,8 +2735,23 @@ def __eq__(self, other):
return self.__dict__ == other.__dict__
return NotImplemented
+ def as_dict(self):
+ warnings.warn(
+ "This method is deprecated. Use BaseSerialization.serialize.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
+ new_dict = dict(self.__dict__)
+ for key in new_dict:
+ if key in ["start_date", "end_date"]:
+ val = new_dict[key]
+ if not val or isinstance(val, str):
+ continue
+ new_dict.update({key: val.isoformat()})
+ return new_dict
+
@classmethod
- def from_ti(cls, ti: TaskInstance):
+ def from_ti(cls, ti: TaskInstance) -> SimpleTaskInstance:
return cls(
dag_id=ti.dag_id,
task_id=ti.task_id,
@@ -2576,17 +2765,22 @@ def from_ti(cls, ti: TaskInstance):
pool=ti.pool,
queue=ti.queue,
key=ti.key,
- run_as_user=ti.run_as_user if hasattr(ti, 'run_as_user') else None,
- priority_weight=ti.priority_weight if hasattr(ti, 'priority_weight') else None,
+ run_as_user=ti.run_as_user if hasattr(ti, "run_as_user") else None,
+ priority_weight=ti.priority_weight if hasattr(ti, "priority_weight") else None,
)
@classmethod
- def from_dict(cls, obj_dict: dict):
- ti_key = TaskInstanceKey(*obj_dict.pop('key'))
+ def from_dict(cls, obj_dict: dict) -> SimpleTaskInstance:
+ warnings.warn(
+ "This method is deprecated. Use BaseSerialization.deserialize.",
+ RemovedInAirflow3Warning,
+ stacklevel=2,
+ )
+ ti_key = TaskInstanceKey(*obj_dict.pop("key"))
start_date = None
end_date = None
- start_date_str: Optional[str] = obj_dict.pop('start_date')
- end_date_str: Optional[str] = obj_dict.pop('end_date')
+ start_date_str: str | None = obj_dict.pop("start_date")
+ end_date_str: str | None = obj_dict.pop("end_date")
if start_date_str:
start_date = timezone.parse(start_date_str)
if end_date_str:
@@ -2594,8 +2788,57 @@ def from_dict(cls, obj_dict: dict):
return cls(**obj_dict, start_date=start_date, end_date=end_date, key=ti_key)
+class TaskInstanceNote(Base):
+ """For storage of arbitrary notes concerning the task instance."""
+
+ __tablename__ = "task_instance_note"
+
+ user_id = Column(Integer, nullable=True)
+ task_id = Column(StringID(), primary_key=True, nullable=False)
+ dag_id = Column(StringID(), primary_key=True, nullable=False)
+ run_id = Column(StringID(), primary_key=True, nullable=False)
+ map_index = Column(Integer, primary_key=True, nullable=False)
+ content = Column(String(1000).with_variant(Text(1000), "mysql"))
+ created_at = Column(UtcDateTime, default=timezone.utcnow, nullable=False)
+ updated_at = Column(UtcDateTime, default=timezone.utcnow, onupdate=timezone.utcnow, nullable=False)
+
+ task_instance = relationship("TaskInstance", back_populates="task_instance_note")
+
+ __table_args__ = (
+ PrimaryKeyConstraint(
+ "task_id", "dag_id", "run_id", "map_index", name="task_instance_note_pkey", mssql_clustered=True
+ ),
+ ForeignKeyConstraint(
+ (dag_id, task_id, run_id, map_index),
+ [
+ "task_instance.dag_id",
+ "task_instance.task_id",
+ "task_instance.run_id",
+ "task_instance.map_index",
+ ],
+ name="task_instance_note_ti_fkey",
+ ondelete="CASCADE",
+ ),
+ ForeignKeyConstraint(
+ (user_id,),
+ ["ab_user.id"],
+ name="task_instance_note_user_fkey",
+ ),
+ )
+
+ def __init__(self, content, user_id=None):
+ self.content = content
+ self.user_id = user_id
+
+ def __repr__(self):
+ prefix = f"<{self.__class__.__name__}: {self.dag_id}.{self.task_id} {self.run_id}"
+ if self.map_index != -1:
+ prefix += f" map_index={self.map_index}"
+ return prefix + ">"
+
+
STATICA_HACK = True
-globals()['kcah_acitats'[::-1].upper()] = False
+globals()["kcah_acitats"[::-1].upper()] = False
if STATICA_HACK: # pragma: no cover
from airflow.jobs.base_job import BaseJob
diff --git a/airflow/models/tasklog.py b/airflow/models/tasklog.py
index a5a3e510fcfb0..3e5a9e195ea56 100644
--- a/airflow/models/tasklog.py
+++ b/airflow/models/tasklog.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from sqlalchemy import Column, Integer, Text
diff --git a/airflow/models/taskmap.py b/airflow/models/taskmap.py
index 3945e2dcd7d30..e7abcc1b6e0ae 100644
--- a/airflow/models/taskmap.py
+++ b/airflow/models/taskmap.py
@@ -15,12 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Table to store information about mapped task instances (AIP-42)."""
+from __future__ import annotations
import collections.abc
import enum
-from typing import TYPE_CHECKING, Any, Collection, List, Optional
+from typing import TYPE_CHECKING, Any, Collection
from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String
@@ -82,7 +82,7 @@ def __init__(
run_id: str,
map_index: int,
length: int,
- keys: Optional[List[Any]],
+ keys: list[Any] | None,
) -> None:
self.dag_id = dag_id
self.task_id = task_id
@@ -92,7 +92,7 @@ def __init__(
self.keys = keys
@classmethod
- def from_task_instance_xcom(cls, ti: "TaskInstance", value: Collection) -> "TaskMap":
+ def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
if ti.run_id is None:
raise ValueError("cannot record task map for unrun task instance")
return cls(
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index 1d66a719a6aec..211fad6ff909d 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -14,21 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import warnings
from abc import ABCMeta, abstractmethod
-from typing import TYPE_CHECKING, Any, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
import pendulum
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.serialization.enums import DagAttributeTypes
if TYPE_CHECKING:
from logging import Logger
from airflow.models.dag import DAG
- from airflow.models.mappedoperator import MappedOperator
+ from airflow.models.operator import Operator
from airflow.utils.edgemodifier import EdgeModifier
from airflow.utils.task_group import TaskGroup
@@ -37,7 +38,7 @@ class DependencyMixin:
"""Mixing implementing common dependency setting methods methods like >> and <<."""
@property
- def roots(self) -> Sequence["DependencyMixin"]:
+ def roots(self) -> Sequence[DependencyMixin]:
"""
List of root nodes -- ones with no upstream dependencies.
@@ -46,7 +47,7 @@ def roots(self) -> Sequence["DependencyMixin"]:
raise NotImplementedError()
@property
- def leaves(self) -> Sequence["DependencyMixin"]:
+ def leaves(self) -> Sequence[DependencyMixin]:
"""
List of leaf nodes -- ones with only upstream dependencies.
@@ -55,37 +56,37 @@ def leaves(self) -> Sequence["DependencyMixin"]:
raise NotImplementedError()
@abstractmethod
- def set_upstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
+ def set_upstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Set a task or a task list to be directly upstream from the current task."""
raise NotImplementedError()
@abstractmethod
- def set_downstream(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
+ def set_downstream(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Set a task or a task list to be directly downstream from the current task."""
raise NotImplementedError()
- def update_relative(self, other: "DependencyMixin", upstream=True) -> None:
+ def update_relative(self, other: DependencyMixin, upstream=True) -> None:
"""
Update relationship information about another TaskMixin. Default is no-op.
Override if necessary.
"""
- def __lshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
+ def __lshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Implements Task << Task"""
self.set_upstream(other)
return other
- def __rshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
+ def __rshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Implements Task >> Task"""
self.set_downstream(other)
return other
- def __rrshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
+ def __rrshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Called for [Task] >> Task because lists don't have __rshift__ operators."""
self.__lshift__(other)
return self
- def __rlshift__(self, other: Union["DependencyMixin", Sequence["DependencyMixin"]]):
+ def __rlshift__(self, other: DependencyMixin | Sequence[DependencyMixin]):
"""Called for [Task] << Task because lists don't have __lshift__ operators."""
self.__rshift__(other)
return self
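The four operator methods in this hunk are easier to follow with a toy class. The reflected variants (`__rrshift__`, `__rlshift__`) only run when the left operand, typically a plain list, does not implement the operator itself. The sketch below is a simplified assumption of that dispatch; `Node` and `_link` are hypothetical and omit DAG assignment, edge modifiers, and everything else `DAGNode` does.

```python
from __future__ import annotations

from collections.abc import Sequence


class Node:
    """Hypothetical minimal node; it only mimics the operator plumbing."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.upstream: set[str] = set()
        self.downstream: set[str] = set()

    def _link(self, other: Node | Sequence[Node], upstream: bool) -> None:
        # Accept a single node or a sequence of nodes.
        others = other if isinstance(other, Sequence) else [other]
        for node in others:
            if upstream:
                self.upstream.add(node.name)
                node.downstream.add(self.name)
            else:
                self.downstream.add(node.name)
                node.upstream.add(self.name)

    def __rshift__(self, other):  # self >> other
        self._link(other, upstream=False)
        return other

    def __lshift__(self, other):  # self << other
        self._link(other, upstream=True)
        return other

    def __rrshift__(self, other):  # [a, b] >> self, since list has no __rshift__
        self.__lshift__(other)
        return self

    def __rlshift__(self, other):  # [a, b] << self, since list has no __lshift__
        self.__rshift__(other)
        return self


a, b, c = Node("a"), Node("b"), Node("c")
a >> b       # a becomes upstream of b
[a, b] >> c  # dispatched to c.__rrshift__([a, b])
```

Returning `other` from the direct operators and `self` from the reflected ones is what lets chains such as `a >> b >> c` read left to right.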
@@ -97,7 +98,7 @@ class TaskMixin(DependencyMixin):
def __init_subclass__(cls) -> None:
warnings.warn(
f"TaskMixin has been renamed to DependencyMixin, please update {cls.__name__}",
- category=DeprecationWarning,
+ category=RemovedInAirflow3Warning,
stacklevel=2,
)
return super().__init_subclass__()
@@ -109,8 +110,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
unmapped.
"""
- dag: Optional["DAG"] = None
- task_group: Optional["TaskGroup"] = None
+ dag: DAG | None = None
+ task_group: TaskGroup | None = None
"""The task_group that contains this node"""
@property
@@ -119,17 +120,17 @@ def node_id(self) -> str:
raise NotImplementedError()
@property
- def label(self) -> Optional[str]:
+ def label(self) -> str | None:
tg = self.task_group
if tg and tg.node_id and tg.prefix_group_id:
# "task_group_id.task_id" -> "task_id"
return self.node_id[len(tg.node_id) + 1 :]
return self.node_id
- start_date: Optional[pendulum.DateTime]
- end_date: Optional[pendulum.DateTime]
- upstream_task_ids: Set[str]
- downstream_task_ids: Set[str]
+ start_date: pendulum.DateTime | None
+ end_date: pendulum.DateTime | None
+ upstream_task_ids: set[str]
+ downstream_task_ids: set[str]
def has_dag(self) -> bool:
return self.dag is not None
@@ -142,24 +143,24 @@ def dag_id(self) -> str:
return "_in_memory_dag_"
@property
- def log(self) -> "Logger":
+ def log(self) -> Logger:
raise NotImplementedError()
@property
@abstractmethod
- def roots(self) -> Sequence["DAGNode"]:
+ def roots(self) -> Sequence[DAGNode]:
raise NotImplementedError()
@property
@abstractmethod
- def leaves(self) -> Sequence["DAGNode"]:
+ def leaves(self) -> Sequence[DAGNode]:
raise NotImplementedError()
def _set_relatives(
self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
+ task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
upstream: bool = False,
- edge_modifier: Optional["EdgeModifier"] = None,
+ edge_modifier: EdgeModifier | None = None,
) -> None:
"""Sets relatives for the task or task list."""
from airflow.models.baseoperator import BaseOperator
@@ -169,7 +170,7 @@ def _set_relatives(
if not isinstance(task_or_task_list, Sequence):
task_or_task_list = [task_or_task_list]
- task_list: List[Operator] = []
+ task_list: list[Operator] = []
for task_object in task_or_task_list:
task_object.update_relative(self, not upstream)
relatives = task_object.leaves if upstream else task_object.roots
@@ -182,10 +183,10 @@ def _set_relatives(
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
- dags: Set["DAG"] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}
+ dags: set[DAG] = {task.dag for task in [*self.roots, *task_list] if task.has_dag() and task.dag}
if len(dags) > 1:
- raise AirflowException(f'Tried to set relationships between tasks in more than one DAG: {dags}')
+ raise AirflowException(f"Tried to set relationships between tasks in more than one DAG: {dags}")
elif len(dags) == 1:
dag = dags.pop()
else:
@@ -195,24 +196,20 @@ def _set_relatives(
)
if not self.has_dag():
- # If this task does not yet have a dag, add it to the same dag as the other task and
- # put it in the dag's root TaskGroup.
+ # If this task does not yet have a dag, add it to the same dag as the other task.
self.dag = dag
- self.dag.task_group.add(self)
- def add_only_new(obj, item_set: Set[str], item: str) -> None:
+ def add_only_new(obj, item_set: set[str], item: str) -> None:
"""Adds only new items to item set"""
if item in item_set:
- self.log.warning('Dependency %s, %s already registered for DAG: %s', obj, item, dag.dag_id)
+ self.log.warning("Dependency %s, %s already registered for DAG: %s", obj, item, dag.dag_id)
else:
item_set.add(item)
for task in task_list:
if dag and not task.has_dag():
# If the other task does not yet have a dag, add it to the same dag as this task.
- # put it in the dag's root TaskGroup.
dag.add_task(task)
- dag.task_group.add(task)
if upstream:
add_only_new(task, task.downstream_task_ids, self.node_id)
add_only_new(self, self.upstream_task_ids, task.node_id)
@@ -226,35 +223,35 @@ def add_only_new(obj, item_set: Set[str], item: str) -> None:
def set_downstream(
self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
- edge_modifier: Optional["EdgeModifier"] = None,
+ task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
+ edge_modifier: EdgeModifier | None = None,
) -> None:
"""Set a node (or nodes) to be directly downstream from the current node."""
self._set_relatives(task_or_task_list, upstream=False, edge_modifier=edge_modifier)
def set_upstream(
self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
- edge_modifier: Optional["EdgeModifier"] = None,
+ task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
+ edge_modifier: EdgeModifier | None = None,
) -> None:
- """Set a node (or nodes) to be directly downstream from the current node."""
+ """Set a node (or nodes) to be directly upstream from the current node."""
self._set_relatives(task_or_task_list, upstream=True, edge_modifier=edge_modifier)
@property
- def downstream_list(self) -> Iterable["DAGNode"]:
+ def downstream_list(self) -> Iterable[Operator]:
"""List of nodes directly downstream"""
if not self.dag:
- raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
+ raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
return [self.dag.get_task(tid) for tid in self.downstream_task_ids]
@property
- def upstream_list(self) -> Iterable["DAGNode"]:
+ def upstream_list(self) -> Iterable[Operator]:
"""List of nodes directly upstream"""
if not self.dag:
- raise AirflowException(f'Operator {self} has not been assigned to a DAG yet')
+ raise AirflowException(f"Operator {self} has not been assigned to a DAG yet")
return [self.dag.get_task(tid) for tid in self.upstream_task_ids]
- def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
+ def get_direct_relative_ids(self, upstream: bool = False) -> set[str]:
"""
Get set of the direct relative ids to the current task, upstream or
downstream.
@@ -264,7 +261,7 @@ def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]:
else:
return self.downstream_task_ids
- def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]:
+ def get_direct_relatives(self, upstream: bool = False) -> Iterable[DAGNode]:
"""
Get list of the direct relatives to the current task, upstream or
downstream.
@@ -274,60 +271,6 @@ def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]:
else:
return self.downstream_list
- def serialize_for_task_group(self) -> Tuple[DagAttributeTypes, Any]:
- """This is used by SerializedTaskGroup to serialize a task group's content."""
+ def serialize_for_task_group(self) -> tuple[DagAttributeTypes, Any]:
+ """This is used by TaskGroupSerialization to serialize a task group's content."""
raise NotImplementedError()
-
- def _iter_all_mapped_downstreams(self) -> Iterator["MappedOperator"]:
- """Return mapped nodes that are direct dependencies of the current task.
-
- For now, this walks the entire DAG to find mapped nodes that has this
- current task as an upstream. We cannot use ``downstream_list`` since it
- only contains operators, not task groups. In the future, we should
- provide a way to record an DAG node's all downstream nodes instead.
-
- Note that this does not guarantee the returned tasks actually use the
- current task for task mapping, but only checks those task are mapped
- operators, and are downstreams of the current task.
-
- To get a list of tasks that uses the current task for task mapping, use
- :meth:`iter_mapped_dependants` instead.
- """
- from airflow.models.mappedoperator import MappedOperator
- from airflow.utils.task_group import TaskGroup
-
- def _walk_group(group: TaskGroup) -> Iterable[Tuple[str, DAGNode]]:
- """Recursively walk children in a task group.
-
- This yields all direct children (including both tasks and task
- groups), and all children of any task groups.
- """
- for key, child in group.children.items():
- yield key, child
- if isinstance(child, TaskGroup):
- yield from _walk_group(child)
-
- tg = self.task_group
- if not tg:
- raise RuntimeError("Cannot check for mapped dependants when not attached to a DAG")
- for key, child in _walk_group(tg):
- if key == self.node_id:
- continue
- if not isinstance(child, MappedOperator):
- continue
- if self.node_id in child.upstream_task_ids:
- yield child
-
- def iter_mapped_dependants(self) -> Iterator["MappedOperator"]:
- """Return mapped nodes that depend on the current task the expansion.
-
- For now, this walks the entire DAG to find mapped nodes that has this
- current task as an upstream. We cannot use ``downstream_list`` since it
- only contains operators, not task groups. In the future, we should
- provide a way to record an DAG node's all downstream nodes instead.
- """
- return (
- downstream
- for downstream in self._iter_all_mapped_downstreams()
- if any(p.node_id == self.node_id for p in downstream.iter_mapped_dependencies())
- )
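This file also switches the `TaskMixin` deprecation from `DeprecationWarning` to `RemovedInAirflow3Warning`. The sketch below shows why a dedicated category can be useful, under the assumption (not visible in this diff) that the new category subclasses `DeprecationWarning`. `RemovedInNextMajorWarning`, `OldMixin`, `Legacy`, and `Quiet` are hypothetical names.

```python
import warnings


class RemovedInNextMajorWarning(DeprecationWarning):
    """Hypothetical project-specific category, assumed to subclass DeprecationWarning."""


class OldMixin:
    """Hypothetical deprecated base class that warns whenever it is subclassed."""

    def __init_subclass__(cls) -> None:
        warnings.warn(
            f"OldMixin has been renamed, please update {cls.__name__}",
            category=RemovedInNextMajorWarning,
            stacklevel=2,  # attribute the warning to the subclass definition
        )
        return super().__init_subclass__()


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")

    class Legacy(OldMixin):  # triggers the warning
        pass

    with warnings.catch_warnings():
        # The dedicated category can be silenced without hiding other deprecations.
        warnings.simplefilter("ignore", RemovedInNextMajorWarning)

        class Quiet(OldMixin):  # warning suppressed inside this block
            pass
```

Existing filters written against `DeprecationWarning` keep matching the subclass, while callers who want to target only this project's removals can filter on the narrower category.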
diff --git a/airflow/models/taskreschedule.py b/airflow/models/taskreschedule.py
index 518f1e77ff65f..bbe09145c0c95 100644
--- a/airflow/models/taskreschedule.py
+++ b/airflow/models/taskreschedule.py
@@ -16,11 +16,12 @@
# specific language governing permissions and limitations
# under the License.
"""TaskReschedule tracks rescheduled task instances."""
+from __future__ import annotations
import datetime
from typing import TYPE_CHECKING
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, text
+from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, String, asc, desc, event, text
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import relationship
@@ -49,7 +50,7 @@ class TaskReschedule(Base):
reschedule_date = Column(UtcDateTime, nullable=False)
__table_args__ = (
- Index('idx_task_reschedule_dag_task_run', dag_id, task_id, run_id, map_index, unique=False),
+ Index("idx_task_reschedule_dag_task_run", dag_id, task_id, run_id, map_index, unique=False),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
@@ -61,11 +62,12 @@ class TaskReschedule(Base):
name="task_reschedule_ti_fkey",
ondelete="CASCADE",
),
+ Index("idx_task_reschedule_dag_run", dag_id, run_id),
ForeignKeyConstraint(
[dag_id, run_id],
- ['dag_run.dag_id', 'dag_run.run_id'],
- name='task_reschedule_dr_fkey',
- ondelete='CASCADE',
+ ["dag_run.dag_id", "dag_run.run_id"],
+ name="task_reschedule_dr_fkey",
+ ondelete="CASCADE",
),
)
dag_run = relationship("DagRun")
@@ -73,7 +75,7 @@ class TaskReschedule(Base):
def __init__(
self,
- task: "BaseOperator",
+ task: BaseOperator,
run_id: str,
try_number: int,
start_date: datetime.datetime,
@@ -111,6 +113,7 @@ def query_for_task_instance(task_instance, descending=False, session=None, try_n
TR.dag_id == task_instance.dag_id,
TR.task_id == task_instance.task_id,
TR.run_id == task_instance.run_id,
+ TR.map_index == task_instance.map_index,
TR.try_number == try_number,
)
if descending:
@@ -133,3 +136,15 @@ def find_for_task_instance(task_instance, session=None, try_number=None):
return TaskReschedule.query_for_task_instance(
task_instance, session=session, try_number=try_number
).all()
+
+
+@event.listens_for(TaskReschedule.__table__, "before_create")
+def add_ondelete_for_mssql(table, conn, **kw):
+ if conn.dialect.name != "mssql":
+ return
+
+ for constraint in table.constraints:
+ if constraint.name != "task_reschedule_dr_fkey":
+ continue
+ constraint.ondelete = "NO ACTION"
+ return
diff --git a/airflow/models/trigger.py b/airflow/models/trigger.py
index dc91fdccdf2ba..57d2ac8f2600a 100644
--- a/airflow/models/trigger.py
+++ b/airflow/models/trigger.py
@@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import datetime
from traceback import format_exception
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Iterable
from sqlalchemy import Column, Integer, String, func, or_
+from sqlalchemy.orm import relationship
from airflow.models.base import Base
from airflow.models.taskinstance import TaskInstance
@@ -26,7 +29,7 @@
from airflow.utils import timezone
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
+from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks
from airflow.utils.state import State
@@ -55,9 +58,14 @@ class Trigger(Base):
created_date = Column(UtcDateTime, nullable=False)
triggerer_id = Column(Integer, nullable=True)
- def __init__(
- self, classpath: str, kwargs: Dict[str, Any], created_date: Optional[datetime.datetime] = None
- ):
+ triggerer_job = relationship(
+ "BaseJob",
+ primaryjoin="BaseJob.id == Trigger.triggerer_id",
+ foreign_keys=triggerer_id,
+ uselist=False,
+ )
+
+ def __init__(self, classpath: str, kwargs: dict[str, Any], created_date: datetime.datetime | None = None):
super().__init__()
self.classpath = classpath
self.kwargs = kwargs
@@ -74,7 +82,7 @@ def from_object(cls, trigger: BaseTrigger):
@classmethod
@provide_session
- def bulk_fetch(cls, ids: Iterable[int], session=None) -> Dict[int, "Trigger"]:
+ def bulk_fetch(cls, ids: Iterable[int], session=None) -> dict[int, Trigger]:
"""
Fetches all of the Triggers by ID and returns a dict mapping
ID -> Trigger instance
@@ -188,15 +196,16 @@ def assign_unassigned(cls, triggerer_id, capacity, session=None):
# Find triggers who do NOT have an alive triggerer_id, and then assign
# up to `capacity` of those to us.
- trigger_ids_query = (
+ trigger_ids_query = with_row_locks(
session.query(cls.id)
- # notin_ doesn't find NULL rows
.filter(or_(cls.triggerer_id.is_(None), cls.triggerer_id.notin_(alive_triggerer_ids)))
- .limit(capacity)
- .all()
- )
- session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update(
- {cls.triggerer_id: triggerer_id},
- synchronize_session=False,
- )
+ .limit(capacity),
+ session,
+ skip_locked=True,
+ ).all()
+ if trigger_ids_query:
+ session.query(cls).filter(cls.id.in_([i.id for i in trigger_ids_query])).update(
+ {cls.triggerer_id: triggerer_id},
+ synchronize_session=False,
+ )
session.commit()
diff --git a/airflow/models/variable.py b/airflow/models/variable.py
index 0904ddb23e6e9..dc7db0f7c09de 100644
--- a/airflow/models/variable.py
+++ b/airflow/models/variable.py
@@ -15,13 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
import logging
-from typing import Any, Optional
+from typing import Any
-from cryptography.fernet import InvalidToken as InvalidFernetToken
from sqlalchemy import Boolean, Column, Integer, String, Text
+from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Session, reconstructor, synonym
@@ -47,7 +48,7 @@ class Variable(Base, LoggingMixin):
id = Column(Integer, primary_key=True)
key = Column(String(ID_LEN), unique=True)
- _val = Column('val', Text)
+ _val = Column("val", Text().with_variant(MEDIUMTEXT, "mysql"))
description = Column(Text)
is_encrypted = Column(Boolean, unique=False, default=False)
@@ -64,14 +65,16 @@ def on_db_load(self):
def __repr__(self):
# Hiding the value
- return f'{self.key} : {self._val}'
+ return f"{self.key} : {self._val}"
def get_val(self):
"""Get Airflow Variable from Metadata DB and decode it using the Fernet Key"""
+ from cryptography.fernet import InvalidToken as InvalidFernetToken
+
if self._val is not None and self.is_encrypted:
try:
fernet = get_fernet()
- return fernet.decrypt(bytes(self._val, 'utf-8')).decode()
+ return fernet.decrypt(bytes(self._val, "utf-8")).decode()
except InvalidFernetToken:
self.log.error("Can't decrypt _val for key=%s, invalid token or value", self.key)
return None
@@ -85,13 +88,13 @@ def set_val(self, value):
"""Encode the specified value with Fernet Key and store it in Variables Table."""
if value is not None:
fernet = get_fernet()
- self._val = fernet.encrypt(bytes(value, 'utf-8')).decode()
+ self._val = fernet.encrypt(bytes(value, "utf-8")).decode()
self.is_encrypted = fernet.is_encrypted
@declared_attr
def val(cls):
"""Get Airflow Variable from Metadata DB and decode it using the Fernet Key"""
- return synonym('_val', descriptor=property(cls.get_val, cls.set_val))
+ return synonym("_val", descriptor=property(cls.get_val, cls.set_val))
@classmethod
def setdefault(cls, key, default, description=None, deserialize_json=False):
@@ -112,7 +115,7 @@ def setdefault(cls, key, default, description=None, deserialize_json=False):
Variable.set(key, default, description=description, serialize_json=deserialize_json)
return default
else:
- raise ValueError('Default Value must be set')
+ raise ValueError("Default Value must be set")
else:
return obj
@@ -135,7 +138,7 @@ def get(
if default_var is not cls.__NO_DEFAULT_SENTINEL:
return default_var
else:
- raise KeyError(f'Variable {key} does not exist')
+ raise KeyError(f"Variable {key} does not exist")
else:
if deserialize_json:
obj = json.loads(var_val)
@@ -151,7 +154,7 @@ def set(
cls,
key: str,
value: Any,
- description: Optional[str] = None,
+ description: str | None = None,
serialize_json: bool = False,
session: Session = None,
):
@@ -196,11 +199,11 @@ def update(
cls.check_for_write_conflict(key)
if cls.get_variable_from_secrets(key=key) is None:
- raise KeyError(f'Variable {key} does not exist')
+ raise KeyError(f"Variable {key} does not exist")
obj = session.query(cls).filter(cls.key == key).first()
if obj is None:
- raise AttributeError(f'Variable {key} does not exist in the Database and cannot be updated.')
+ raise AttributeError(f"Variable {key} does not exist in the Database and cannot be updated.")
cls.set(key, value, description=obj.description, serialize_json=serialize_json)
@@ -219,7 +222,7 @@ def rotate_fernet_key(self):
"""Rotate Fernet Key"""
fernet = get_fernet()
if self._val and self.is_encrypted:
- self._val = fernet.rotate(self._val.encode('utf-8')).decode()
+ self._val = fernet.rotate(self._val.encode("utf-8")).decode()
@staticmethod
def check_for_write_conflict(key: str) -> None:
@@ -246,14 +249,14 @@ def check_for_write_conflict(key: str) -> None:
return
except Exception:
log.exception(
- 'Unable to retrieve variable from secrets backend (%s). '
- 'Checking subsequent secrets backend.',
+ "Unable to retrieve variable from secrets backend (%s). "
+ "Checking subsequent secrets backend.",
type(secrets_backend).__name__,
)
return None
@staticmethod
- def get_variable_from_secrets(key: str) -> Optional[str]:
+ def get_variable_from_secrets(key: str) -> str | None:
"""
Get Airflow Variable by iterating over all Secret Backends.
@@ -267,8 +270,8 @@ def get_variable_from_secrets(key: str) -> Optional[str]:
return var_val
except Exception:
log.exception(
- 'Unable to retrieve variable from secrets backend (%s). '
- 'Checking subsequent secrets backend.',
+ "Unable to retrieve variable from secrets backend (%s). "
+ "Checking subsequent secrets backend.",
type(secrets_backend).__name__,
)
return None
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index d67160d1fa901..3b4361842409f 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -15,26 +15,44 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+import collections.abc
+import contextlib
import datetime
import inspect
+import itertools
import json
import logging
import pickle
import warnings
from functools import wraps
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Type, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Generator, Iterable, cast, overload
+import attr
import pendulum
-from sqlalchemy import Column, ForeignKeyConstraint, Index, Integer, LargeBinary, String
+from sqlalchemy import (
+ Column,
+ ForeignKeyConstraint,
+ Index,
+ Integer,
+ LargeBinary,
+ PrimaryKeyConstraint,
+ String,
+ text,
+)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, Session, reconstructor, relationship
from sqlalchemy.orm.exc import NoResultFound
+from airflow import settings
+from airflow.compat.functools import cached_property
from airflow.configuration import conf
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
from airflow.utils import timezone
from airflow.utils.helpers import exactly_one, is_container
+from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
@@ -44,7 +62,7 @@
# MAX XCOM Size is 48KB
# https://github.com/apache/airflow/pull/1618#discussion_r68249677
MAX_XCOM_SIZE = 49344
-XCOM_RETURN_KEY = 'return_value'
+XCOM_RETURN_KEY = "return_value"
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
@@ -57,7 +75,7 @@ class BaseXCom(Base, LoggingMixin):
dag_run_id = Column(Integer(), nullable=False, primary_key=True)
task_id = Column(String(ID_LEN, **COLLATION_ARGS), nullable=False, primary_key=True)
- map_index = Column(Integer, primary_key=True, nullable=False, server_default="-1")
+ map_index = Column(Integer, primary_key=True, nullable=False, server_default=text("-1"))
key = Column(String(512, **COLLATION_ARGS), nullable=False, primary_key=True)
# Denormalized for easier lookup.
@@ -72,6 +90,10 @@ class BaseXCom(Base, LoggingMixin):
# but it goes over MySQL's index length limit. So we instead index 'key'
# separately, and enforce uniqueness with DagRun.id instead.
Index("idx_xcom_key", key),
+ Index("idx_xcom_task_instance", dag_id, task_id, run_id, map_index),
+ PrimaryKeyConstraint(
+ "dag_run_id", "task_id", "map_index", "key", name="xcom_pkey", mssql_clustered=True
+ ),
ForeignKeyConstraint(
[dag_id, task_id, run_id, map_index],
[
@@ -157,10 +179,10 @@ def set(
value: Any,
task_id: str,
dag_id: str,
- execution_date: Optional[datetime.datetime] = None,
+ execution_date: datetime.datetime | None = None,
session: Session = NEW_SESSION,
*,
- run_id: Optional[str] = None,
+ run_id: str | None = None,
map_index: int = -1,
) -> None:
""":sphinx-autoapi-skip:"""
@@ -174,7 +196,7 @@ def set(
if run_id is None:
message = "Passing 'execution_date' to 'XCom.set()' is deprecated. Use 'run_id' instead."
- warnings.warn(message, DeprecationWarning, stacklevel=3)
+ warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
try:
dag_run_id, run_id = (
session.query(DagRun.id, DagRun.run_id)
@@ -188,6 +210,27 @@ def set(
if dag_run_id is None:
raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID {run_id!r}")
+ # Seamlessly resolve LazyXComAccess to a list. This is intended to work
+ # as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but if
+ # it's pushed into XCom, the user should be aware of the performance
+ # implications, and this avoids leaking the implementation detail.
+ if isinstance(value, LazyXComAccess):
+ warning_message = (
+ "Coercing mapped lazy proxy %s from task %s (DAG %s, run %s) "
+ "to list, which may degrade performance. Review resource "
+ "requirements for this operation, and call list() to suppress "
+ "this message. See Dynamic Task Mapping documentation for "
+ "more information about lazy proxy objects."
+ )
+ log.warning(
+ warning_message,
+ "return value" if key == XCOM_RETURN_KEY else f"value {key}",
+ task_id,
+ dag_id,
+ run_id or execution_date,
+ )
+ value = list(value)
+
value = cls.serialize_value(
value=value,
key=key,
@@ -222,8 +265,8 @@ def set(
def get_value(
cls,
*,
- ti_key: "TaskInstanceKey",
- key: Optional[str] = None,
+ ti_key: TaskInstanceKey,
+ key: str | None = None,
session: Session = NEW_SESSION,
) -> Any:
"""Retrieve an XCom value for a task instance.
@@ -255,13 +298,13 @@ def get_value(
def get_one(
cls,
*,
- key: Optional[str] = None,
- dag_id: Optional[str] = None,
- task_id: Optional[str] = None,
- run_id: Optional[str] = None,
- map_index: Optional[int] = None,
+ key: str | None = None,
+ dag_id: str | None = None,
+ task_id: str | None = None,
+ run_id: str | None = None,
+ map_index: int | None = None,
session: Session = NEW_SESSION,
- ) -> Optional[Any]:
+ ) -> Any | None:
"""Retrieve an XCom value, optionally meeting certain criteria.
This method returns "full" XCom values (i.e. uses ``deserialize_value``
@@ -298,28 +341,28 @@ def get_one(
def get_one(
cls,
execution_date: datetime.datetime,
- key: Optional[str] = None,
- task_id: Optional[str] = None,
- dag_id: Optional[str] = None,
+ key: str | None = None,
+ task_id: str | None = None,
+ dag_id: str | None = None,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
- ) -> Optional[Any]:
+ ) -> Any | None:
""":sphinx-autoapi-skip:"""
@classmethod
@provide_session
def get_one(
cls,
- execution_date: Optional[datetime.datetime] = None,
- key: Optional[str] = None,
- task_id: Optional[str] = None,
- dag_id: Optional[str] = None,
+ execution_date: datetime.datetime | None = None,
+ key: str | None = None,
+ task_id: str | None = None,
+ dag_id: str | None = None,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
*,
- run_id: Optional[str] = None,
- map_index: Optional[int] = None,
- ) -> Optional[Any]:
+ run_id: str | None = None,
+ map_index: int | None = None,
+ ) -> Any | None:
""":sphinx-autoapi-skip:"""
if not exactly_one(execution_date is not None, run_id is not None):
raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed")
@@ -337,10 +380,10 @@ def get_one(
)
elif execution_date is not None:
message = "Passing 'execution_date' to 'XCom.get_one()' is deprecated. Use 'run_id' instead."
- warnings.warn(message, PendingDeprecationWarning, stacklevel=3)
+ warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
with warnings.catch_warnings():
- warnings.simplefilter("ignore", DeprecationWarning)
+ warnings.simplefilter("ignore", RemovedInAirflow3Warning)
query = cls.get_many(
execution_date=execution_date,
key=key,
@@ -365,12 +408,12 @@ def get_many(
cls,
*,
run_id: str,
- key: Optional[str] = None,
- task_ids: Union[str, Iterable[str], None] = None,
- dag_ids: Union[str, Iterable[str], None] = None,
- map_indexes: Union[int, Iterable[int], None] = None,
+ key: str | None = None,
+ task_ids: str | Iterable[str] | None = None,
+ dag_ids: str | Iterable[str] | None = None,
+ map_indexes: int | Iterable[int] | None = None,
include_prior_dates: bool = False,
- limit: Optional[int] = None,
+ limit: int | None = None,
session: Session = NEW_SESSION,
) -> Query:
"""Composes a query to get one or more XCom entries.
@@ -402,12 +445,12 @@ def get_many(
def get_many(
cls,
execution_date: datetime.datetime,
- key: Optional[str] = None,
- task_ids: Union[str, Iterable[str], None] = None,
- dag_ids: Union[str, Iterable[str], None] = None,
- map_indexes: Union[int, Iterable[int], None] = None,
+ key: str | None = None,
+ task_ids: str | Iterable[str] | None = None,
+ dag_ids: str | Iterable[str] | None = None,
+ map_indexes: int | Iterable[int] | None = None,
include_prior_dates: bool = False,
- limit: Optional[int] = None,
+ limit: int | None = None,
session: Session = NEW_SESSION,
) -> Query:
""":sphinx-autoapi-skip:"""
@@ -416,16 +459,16 @@ def get_many(
@provide_session
def get_many(
cls,
- execution_date: Optional[datetime.datetime] = None,
- key: Optional[str] = None,
- task_ids: Optional[Union[str, Iterable[str]]] = None,
- dag_ids: Optional[Union[str, Iterable[str]]] = None,
- map_indexes: Union[int, Iterable[int], None] = None,
+ execution_date: datetime.datetime | None = None,
+ key: str | None = None,
+ task_ids: str | Iterable[str] | None = None,
+ dag_ids: str | Iterable[str] | None = None,
+ map_indexes: int | Iterable[int] | None = None,
include_prior_dates: bool = False,
- limit: Optional[int] = None,
+ limit: int | None = None,
session: Session = NEW_SESSION,
*,
- run_id: Optional[str] = None,
+ run_id: str | None = None,
) -> Query:
""":sphinx-autoapi-skip:"""
from airflow.models.dagrun import DagRun
@@ -437,7 +480,7 @@ def get_many(
)
if execution_date is not None:
message = "Passing 'execution_date' to 'XCom.get_many()' is deprecated. Use 'run_id' instead."
- warnings.warn(message, PendingDeprecationWarning, stacklevel=3)
+ warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
query = session.query(cls).join(cls.dag_run)
@@ -454,7 +497,9 @@ def get_many(
elif dag_ids is not None:
query = query.filter(cls.dag_id == dag_ids)
- if is_container(map_indexes):
+ if isinstance(map_indexes, range) and map_indexes.step == 1:
+ query = query.filter(cls.map_index >= map_indexes.start, cls.map_index < map_indexes.stop)
+ elif is_container(map_indexes):
query = query.filter(cls.map_index.in_(map_indexes))
elif map_indexes is not None:
query = query.filter(cls.map_index == map_indexes)
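The new `range` branch in this hunk turns a contiguous `range` into two comparisons instead of an `IN (...)` list with one bound parameter per index. The sketch below reproduces that decision with the standard-library `sqlite3` module rather than SQLAlchemy; `map_index_clause`, `fetch`, and the `xcom_sketch` table are hypothetical, and the generic iterable fallback is an assumption standing in for Airflow's `is_container` helper.

```python
from __future__ import annotations

import sqlite3
from collections.abc import Iterable


def map_index_clause(map_indexes: int | Iterable[int] | None) -> tuple[str, list[int]]:
    """Return a hypothetical (WHERE fragment, parameters) pair for a map_index filter."""
    if isinstance(map_indexes, range) and map_indexes.step == 1:
        # A contiguous range needs only its two bounds.
        return "map_index >= ? AND map_index < ?", [map_indexes.start, map_indexes.stop]
    if map_indexes is None:
        return "1 = 1", []  # no filter requested
    if isinstance(map_indexes, int):
        return "map_index = ?", [map_indexes]
    values = list(map_indexes)
    if not values:
        return "1 = 0", []  # empty container matches nothing
    placeholders = ", ".join("?" for _ in values)
    return f"map_index IN ({placeholders})", values


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE xcom_sketch (map_index INTEGER)")
conn.executemany("INSERT INTO xcom_sketch VALUES (?)", [(i,) for i in range(10)])


def fetch(map_indexes: int | Iterable[int] | None) -> list[int]:
    """Run the filter against the in-memory table and return matching indexes."""
    clause, params = map_index_clause(map_indexes)
    rows = conn.execute(
        f"SELECT map_index FROM xcom_sketch WHERE {clause} ORDER BY map_index", params
    )
    return [row[0] for row in rows]
```

A stepped range such as `range(0, 10, 2)` is not contiguous, so it falls through to the `IN (...)` branch, which mirrors the `step == 1` guard in the patch.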
@@ -477,13 +522,13 @@ def get_many(
@classmethod
@provide_session
- def delete(cls, xcoms: Union["XCom", Iterable["XCom"]], session: Session) -> None:
+ def delete(cls, xcoms: XCom | Iterable[XCom], session: Session) -> None:
"""Delete one or multiple XCom entries."""
if isinstance(xcoms, XCom):
xcoms = [xcoms]
for xcom in xcoms:
if not isinstance(xcom, XCom):
- raise TypeError(f'Expected XCom; received {xcom.__class__.__name__}')
+ raise TypeError(f"Expected XCom; received {xcom.__class__.__name__}")
session.delete(xcom)
session.commit()
@@ -495,7 +540,7 @@ def clear(
dag_id: str,
task_id: str,
run_id: str,
- map_index: Optional[int] = None,
+ map_index: int | None = None,
session: Session = NEW_SESSION,
) -> None:
"""Clear all XCom data from the database for the given task instance.
@@ -527,13 +572,13 @@ def clear(
@provide_session
def clear(
cls,
- execution_date: Optional[pendulum.DateTime] = None,
- dag_id: Optional[str] = None,
- task_id: Optional[str] = None,
+ execution_date: pendulum.DateTime | None = None,
+ dag_id: str | None = None,
+ task_id: str | None = None,
session: Session = NEW_SESSION,
*,
- run_id: Optional[str] = None,
- map_index: Optional[int] = None,
+ run_id: str | None = None,
+ map_index: int | None = None,
) -> None:
""":sphinx-autoapi-skip:"""
from airflow.models import DagRun
@@ -553,7 +598,7 @@ def clear(
if execution_date is not None:
message = "Passing 'execution_date' to 'XCom.clear()' is deprecated. Use 'run_id' instead."
- warnings.warn(message, DeprecationWarning, stacklevel=3)
+ warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
run_id = (
session.query(DagRun.run_id)
.filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date)
@@ -569,42 +614,52 @@ def clear(
def serialize_value(
value: Any,
*,
- key: Optional[str] = None,
- task_id: Optional[str] = None,
- dag_id: Optional[str] = None,
- run_id: Optional[str] = None,
- map_index: Optional[int] = None,
- ):
- """Serialize XCom value to str or pickled object"""
- if conf.getboolean('core', 'enable_xcom_pickling'):
+ key: str | None = None,
+ task_id: str | None = None,
+ dag_id: str | None = None,
+ run_id: str | None = None,
+ map_index: int | None = None,
+ ) -> Any:
+ """Serialize XCom value to str or pickled object."""
+ if conf.getboolean("core", "enable_xcom_pickling"):
return pickle.dumps(value)
try:
- return json.dumps(value).encode('UTF-8')
- except (ValueError, TypeError):
+ return json.dumps(value, cls=XComEncoder).encode("UTF-8")
+ except (ValueError, TypeError) as ex:
log.error(
- "Could not serialize the XCom value into JSON."
+ "%s."
" If you are using pickle instead of JSON for XCom,"
" then you need to enable pickle support for XCom"
- " in your airflow config."
+ " in your airflow config or make sure to decorate your"
+ " object with attr.",
+ ex,
)
raise
@staticmethod
- def deserialize_value(result: "XCom") -> Any:
- """Deserialize XCom value from str or pickle object"""
+ def _deserialize_value(result: XCom, orm: bool) -> Any:
+ object_hook = None
+ if orm:
+ object_hook = XComDecoder.orm_object_hook
+
if result.value is None:
return None
- if conf.getboolean('core', 'enable_xcom_pickling'):
+ if conf.getboolean("core", "enable_xcom_pickling"):
try:
return pickle.loads(result.value)
except pickle.UnpicklingError:
- return json.loads(result.value.decode('UTF-8'))
+ return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
else:
try:
- return json.loads(result.value.decode('UTF-8'))
+ return json.loads(result.value.decode("UTF-8"), cls=XComDecoder, object_hook=object_hook)
except (json.JSONDecodeError, UnicodeDecodeError):
return pickle.loads(result.value)
+ @staticmethod
+ def deserialize_value(result: XCom) -> Any:
+ """Deserialize XCom value from str or pickle object"""
+ return BaseXCom._deserialize_value(result, False)
+
def orm_deserialize_value(self) -> Any:
"""
Deserialize method which is used to reconstruct ORM XCom object.
@@ -614,10 +669,109 @@ def orm_deserialize_value(self) -> Any:
creating XCom orm model. This is used when viewing XCom listing
in the webserver, for example.
"""
- return BaseXCom.deserialize_value(self)
+ return BaseXCom._deserialize_value(self, True)
+
+
+class _LazyXComAccessIterator(collections.abc.Iterator):
+ def __init__(self, cm: contextlib.AbstractContextManager[Query]) -> None:
+ self._cm = cm
+ self._entered = False
+
+ def __del__(self) -> None:
+ if self._entered:
+ self._cm.__exit__(None, None, None)
+
+ def __iter__(self) -> collections.abc.Iterator:
+ return self
+ def __next__(self) -> Any:
+ return XCom.deserialize_value(next(self._it))
-def _patch_outdated_serializer(clazz: Type[BaseXCom], params: Iterable[str]) -> None:
+ @cached_property
+ def _it(self) -> collections.abc.Iterator:
+ self._entered = True
+ return iter(self._cm.__enter__())
+
+
+@attr.define(slots=True)
+class LazyXComAccess(collections.abc.Sequence):
+ """Wrapper to lazily pull XCom with a sequence-like interface.
+
+ Note that since the session bound to the parent query may have died when we
+ actually access the sequence's content, we must create a new session
+ for every function call with ``with_session()``.
+
+ :meta private:
+ """
+
+ _query: Query
+ _len: int | None = attr.ib(init=False, default=None)
+
+ @classmethod
+ def build_from_xcom_query(cls, query: Query) -> LazyXComAccess:
+ return cls(query=query.with_entities(XCom.value))
+
+ def __repr__(self) -> str:
+ return f"LazyXComAccess([{len(self)} items])"
+
+ def __str__(self) -> str:
+ return str(list(self))
+
+ def __eq__(self, other: Any) -> bool:
+ if isinstance(other, (list, LazyXComAccess)):
+ z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
+ return all(x == y for x, y in z)
+ return NotImplemented
+
+ def __getstate__(self) -> Any:
+ # We don't want to go to the trouble of serializing the entire Query
+ # object, including its filters, hints, etc. (plus SQLAlchemy does not
+ # provide a public API to inspect a query's contents). Converting the
+ # query into a SQL string is the best we can get. Theoretically we can
+ # do the same for count(), but I think it should be performant enough to
+ # calculate only that eagerly.
+ with self._get_bound_query() as query:
+ statement = query.statement.compile(query.session.get_bind())
+ return (str(statement), query.count())
+
+ def __setstate__(self, state: Any) -> None:
+ statement, self._len = state
+ self._query = Query(XCom.value).from_statement(text(statement))
+
+ def __len__(self):
+ if self._len is None:
+ with self._get_bound_query() as query:
+ self._len = query.count()
+ return self._len
+
+ def __iter__(self):
+ return _LazyXComAccessIterator(self._get_bound_query())
+
+ def __getitem__(self, key):
+ if not isinstance(key, int):
+ raise ValueError("only support index access for now")
+ try:
+ with self._get_bound_query() as query:
+ r = query.offset(key).limit(1).one()
+ except NoResultFound:
+ raise IndexError(key) from None
+ return XCom.deserialize_value(r)
+
+ @contextlib.contextmanager
+ def _get_bound_query(self) -> Generator[Query, None, None]:
+ # Do we have a valid session already?
+ if self._query.session and self._query.session.is_active:
+ yield self._query
+ return
+
+ session = settings.Session()
+ try:
+ yield self._query.with_session(session)
+ finally:
+ session.close()
+
+
+def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
"""Patch a custom ``serialize_value`` to accept the modern signature.
To give custom XCom backends more flexibility with how they store values, we
@@ -635,19 +789,18 @@ def _shim(**kwargs):
f"Method `serialize_value` in XCom backend {XCom.__name__} is using outdated signature and"
f"must be updated to accept all params in `BaseXCom.set` except `session`. Support will be "
f"removed in a future release.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
)
return old_serializer(**kwargs)
clazz.serialize_value = _shim # type: ignore[assignment]
-def _get_function_params(function) -> List[str]:
+def _get_function_params(function) -> list[str]:
"""
Returns the list of variables names of a function
:param function: The function to inspect
- :rtype: List[str]
"""
parameters = inspect.signature(function).parameters
bound_arguments = [
@@ -656,8 +809,12 @@ def _get_function_params(function) -> List[str]:
return bound_arguments
-def resolve_xcom_backend() -> Type[BaseXCom]:
- """Resolves custom XCom class"""
+def resolve_xcom_backend() -> type[BaseXCom]:
+ """Resolve the custom XCom class.
+
+ Confirms that the custom XCom class extends ``BaseXCom`` and compares the
+ signature of its ``serialize_value`` with that of ``BaseXCom.serialize_value``.
+ """
clazz = conf.getimport("core", "xcom_backend", fallback=f"airflow.models.xcom.{BaseXCom.__name__}")
if not clazz:
return BaseXCom
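Reviewer note on the ``LazyXComAccess.__eq__`` added in this file: it pads the shorter side via ``itertools.zip_longest`` with a fresh ``object()``, so inputs of different length never compare equal. Below is a minimal standalone sketch of that idea with no Airflow dependency. The helper name ``sequences_equal`` is hypothetical and not part of the patch.

```python
import itertools
from typing import Any, Iterable


def sequences_equal(left: Iterable[Any], right: Iterable[Any]) -> bool:
    """Compare two iterables element by element without materializing them."""
    # A fresh object() is not equal to any ordinary value, so padding the
    # shorter iterable with it forces a mismatch when the lengths differ.
    missing = object()
    pairs = itertools.zip_longest(left, right, fillvalue=missing)
    return all(x == y for x, y in pairs)


print(sequences_equal([1, 2, 3], iter([1, 2, 3])))  # True
print(sequences_equal([1, 2], [1, 2, 3]))  # False
```

Because ``all()`` short-circuits, the comparison stops at the first mismatch, which matters when one side is backed by a database query.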
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 2c2a8f9b5868d..d2b80474a9f12 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -14,33 +14,46 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence, Union
-from airflow.exceptions import AirflowException
+from __future__ import annotations
+
+import contextlib
+import inspect
+from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload
+
+from sqlalchemy import func
+from sqlalchemy.orm import Session
+
+from airflow.exceptions import XComNotFound
from airflow.models.abstractoperator import AbstractOperator
+from airflow.models.mappedoperator import MappedOperator
from airflow.models.taskmixin import DAGNode, DependencyMixin
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.utils.context import Context
from airflow.utils.edgemodifier import EdgeModifier
+from airflow.utils.mixins import ResolveMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.types import NOTSET
+from airflow.utils.types import NOTSET, ArgNotSet
if TYPE_CHECKING:
- from sqlalchemy.orm import Session
-
+ from airflow.models.dag import DAG
from airflow.models.operator import Operator
+# Callable objects contained by MapXComArg. We only accept callables from
+# the user, but deserialize them into strings in a serialized XComArg for
+# safety (those callables are arbitrary user code).
+MapCallables = Sequence[Union[Callable[[Any], Any], str]]
-class XComArg(DependencyMixin):
- """
- Class that represents a XCom push from a previous operator.
- Defaults to "return_value" as only key.
- Current implementation supports
+class XComArg(ResolveMixin, DependencyMixin):
+ """Reference to an XCom value pushed from another operator.
+
+ The implementation supports::
+
xcomarg >> op
xcomarg << op
- op >> xcomarg (by BaseOperator code)
- op << xcomarg (by BaseOperator code)
+ op >> xcomarg # By BaseOperator code
+ op << xcomarg # By BaseOperator code
**Example**: The moment you get a result from any operator (decorated or regular) you can ::
@@ -53,29 +66,175 @@ class XComArg(DependencyMixin):
This object can be used in legacy Operators via Jinja.
- **Example**: You can make this result to be part of any generated string ::
+ **Example**: You can make this result part of any generated string::
any_op = AnyOperator()
xcomarg = any_op.output
op1 = MyOperator(my_text_message=f"the value is {xcomarg}")
op2 = MyOperator(my_text_message=f"the value is {xcomarg['topic']}")
- :param operator: operator to which the XComArg belongs to
- :param key: key value which is used for xcom_pull (key in the XCom table)
+ :param operator: Operator instance that the XComArg references.
+ :param key: Key used to pull the XCom value. Defaults to *XCOM_RETURN_KEY*,
+ i.e. the referenced operator's return value.
"""
- def __init__(self, operator: "Operator", key: str = XCOM_RETURN_KEY):
+ @overload
+ def __new__(cls: type[XComArg], operator: Operator, key: str = XCOM_RETURN_KEY) -> XComArg:
+ """Called when the user writes ``XComArg(...)`` directly."""
+
+ @overload
+ def __new__(cls: type[XComArg]) -> XComArg:
+ """Called by Python internals from subclasses."""
+
+ def __new__(cls, *args, **kwargs) -> XComArg:
+ if cls is XComArg:
+ return PlainXComArg(*args, **kwargs)
+ return super().__new__(cls)
+
+ @staticmethod
+ def iter_xcom_references(arg: Any) -> Iterator[tuple[Operator, str]]:
+ """Return XCom references in an arbitrary value.
+
+ Recursively traverse ``arg`` and look for XComArg instances in any
+ collection objects, and instances with ``template_fields`` set.
+ """
+ if isinstance(arg, ResolveMixin):
+ yield from arg.iter_references()
+ elif isinstance(arg, (tuple, set, list)):
+ for elem in arg:
+ yield from XComArg.iter_xcom_references(elem)
+ elif isinstance(arg, dict):
+ for elem in arg.values():
+ yield from XComArg.iter_xcom_references(elem)
+ elif isinstance(arg, AbstractOperator):
+ for attr in arg.template_fields:
+ yield from XComArg.iter_xcom_references(getattr(arg, attr))
+
+ @staticmethod
+ def apply_upstream_relationship(op: Operator, arg: Any):
+ """Set dependency for XComArgs.
+
+ This looks for XComArg objects in ``arg`` "deeply" (looking inside
+ collections objects and classes decorated with ``template_fields``), and
+ sets the relationship to ``op`` on any found.
+ """
+ for operator, _ in XComArg.iter_xcom_references(arg):
+ op.set_upstream(operator)
+
+ @property
+ def roots(self) -> list[DAGNode]:
+ """Required by TaskMixin"""
+ return [op for op, _ in self.iter_references()]
+
+ @property
+ def leaves(self) -> list[DAGNode]:
+ """Required by TaskMixin"""
+ return [op for op, _ in self.iter_references()]
+
+ def set_upstream(
+ self,
+ task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
+ edge_modifier: EdgeModifier | None = None,
+ ):
+ """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
+ for operator, _ in self.iter_references():
+ operator.set_upstream(task_or_task_list, edge_modifier)
+
+ def set_downstream(
+ self,
+ task_or_task_list: DependencyMixin | Sequence[DependencyMixin],
+ edge_modifier: EdgeModifier | None = None,
+ ):
+ """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
+ for operator, _ in self.iter_references():
+ operator.set_downstream(task_or_task_list, edge_modifier)
+
+ def _serialize(self) -> dict[str, Any]:
+ """Called by DAG serialization.
+
+ The implementation should be the inverse function to ``_deserialize``,
+ returning a data dict converted from this XComArg derivative. DAG
+ serialization does not call this directly, but ``serialize_xcom_arg``
+ instead, which adds additional information to dispatch deserialization
+ to the correct class.
+ """
+ raise NotImplementedError()
+
+ @classmethod
+ def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
+ """Called when deserializing a DAG.
+
+ The implementation should be the inverse function to ``_serialize``:
+ given a data dict converted from this XComArg derivative, it defines
+ how the original XComArg should be recreated. DAG serialization relies on
+ additional information added in ``serialize_xcom_arg`` to dispatch data
+ dicts to the correct ``_deserialize`` implementation, so this function does
+ not need to validate whether the incoming data contains correct keys.
+ """
+ raise NotImplementedError()
+
+ def map(self, f: Callable[[Any], Any]) -> MapXComArg:
+ return MapXComArg(self, [f])
+
+ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
+ return ZipXComArg([self, *others], fillvalue=fillvalue)
+
+ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
+ """Inspect length of pushed value for task-mapping.
+
+ This is used to determine how many task instances the scheduler should
+ create for a downstream using this XComArg for task-mapping.
+
+ *None* may be returned if the XCom this depends on has not been pushed yet.
+ """
+ raise NotImplementedError()
+
+ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ """Pull XCom value.
+
+ This should only be called during ``op.execute()`` with an appropriate
+ context (e.g. generated from ``TaskInstance.get_template_context()``).
+ Although the ``ResolveMixin`` parent mixin also has a ``resolve``
+ protocol, this adds the optional ``session`` argument that some of the
+ subclasses need.
+
+ :meta private:
+ """
+ raise NotImplementedError()
+
+
+class PlainXComArg(XComArg):
+ """Reference to one single XCom without any additional semantics.
+
+ This class should not be accessed directly, but only through XComArg. The
+ class inheritance chain and ``__new__`` are implemented in this slightly
+ convoluted way because we want to
+
+ a. Allow the user to continue using XComArg directly for the simple
+ semantics (see documentation of the base class for details).
+ b. Make ``isinstance(thing, XComArg)`` be able to detect all kinds of XCom
+ references.
+ c. Not allow many properties of PlainXComArg (including ``__getitem__`` and
+ ``__str__``) to exist on other kinds of XComArg implementations since
+ they don't make sense.
+
+ :meta private:
+ """
+
+ def __init__(self, operator: Operator, key: str = XCOM_RETURN_KEY):
self.operator = operator
self.key = key
- def __eq__(self, other):
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, PlainXComArg):
+ return NotImplemented
return self.operator == other.operator and self.key == other.key
- def __getitem__(self, item: str) -> "XComArg":
+ def __getitem__(self, item: str) -> XComArg:
"""Implements xcomresult['some_result_key']"""
if not isinstance(item, str):
raise ValueError(f"XComArg only supports str lookup, received {type(item).__name__}")
- return XComArg(operator=self.operator, key=item)
+ return PlainXComArg(operator=self.operator, key=item)
def __iter__(self):
"""Override iterable protocol to raise error explicitly.
@@ -89,9 +248,14 @@ def __iter__(self):
This override catches the error eagerly, so an incorrectly implemented
DAG fails fast and avoids wasting resources on nonsensical iterating.
"""
- raise TypeError(f"{self.__class__.__name__!r} object is not iterable")
+ raise TypeError("'XComArg' object is not iterable")
+
+ def __repr__(self) -> str:
+ if self.key == XCOM_RETURN_KEY:
+ return f"XComArg({self.operator!r})"
+ return f"XComArg({self.operator!r}, {self.key!r})"
- def __str__(self):
+ def __str__(self) -> str:
"""
Backward compatibility for old-style jinja used in Airflow Operators
@@ -103,84 +267,267 @@ def __str__(self):
"""
xcom_pull_kwargs = [
f"task_ids='{self.operator.task_id}'",
- f"dag_id='{self.operator.dag.dag_id}'",
+ f"dag_id='{self.operator.dag_id}'",
]
if self.key is not None:
xcom_pull_kwargs.append(f"key='{self.key}'")
- xcom_pull_kwargs = ", ".join(xcom_pull_kwargs)
+ xcom_pull_str = ", ".join(xcom_pull_kwargs)
# {{{{ are required for escape {{ in f-string
- xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_kwargs}) }}}}"
+ xcom_pull = f"{{{{ task_instance.xcom_pull({xcom_pull_str}) }}}}"
return xcom_pull
- @property
- def roots(self) -> List[DAGNode]:
- """Required by TaskMixin"""
- return [self.operator]
+ def _serialize(self) -> dict[str, Any]:
+ return {"task_id": self.operator.task_id, "key": self.key}
+
+ @classmethod
+ def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
+ return cls(dag.get_task(data["task_id"]), data["key"])
+
+ def iter_references(self) -> Iterator[tuple[Operator, str]]:
+ yield self.operator, self.key
+
+ def map(self, f: Callable[[Any], Any]) -> MapXComArg:
+ if self.key != XCOM_RETURN_KEY:
+ raise ValueError("cannot map against non-return XCom")
+ return super().map(f)
+
+ def zip(self, *others: XComArg, fillvalue: Any = NOTSET) -> ZipXComArg:
+ if self.key != XCOM_RETURN_KEY:
+ raise ValueError("cannot map against non-return XCom")
+ return super().zip(*others, fillvalue=fillvalue)
+
+ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
+ from airflow.models.taskmap import TaskMap
+ from airflow.models.xcom import XCom
+
+ task = self.operator
+ if isinstance(task, MappedOperator):
+ query = session.query(func.count(XCom.map_index)).filter(
+ XCom.dag_id == task.dag_id,
+ XCom.run_id == run_id,
+ XCom.task_id == task.task_id,
+ XCom.map_index >= 0,
+ XCom.key == XCOM_RETURN_KEY,
+ )
+ else:
+ query = session.query(TaskMap.length).filter(
+ TaskMap.dag_id == task.dag_id,
+ TaskMap.run_id == run_id,
+ TaskMap.task_id == task.task_id,
+ TaskMap.map_index < 0,
+ )
+ return query.scalar()
- @property
- def leaves(self) -> List[DAGNode]:
- """Required by TaskMixin"""
- return [self.operator]
+ @provide_session
+ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ ti = context["ti"]
+ task_id = self.operator.task_id
+ map_indexes = ti.get_relevant_upstream_map_indexes(
+ self.operator,
+ context["expanded_ti_count"],
+ session=session,
+ )
+ result = ti.xcom_pull(
+ task_ids=task_id,
+ map_indexes=map_indexes,
+ key=self.key,
+ default=NOTSET,
+ session=session,
+ )
+ if not isinstance(result, ArgNotSet):
+ return result
+ if self.key == XCOM_RETURN_KEY:
+ return None
+ raise XComNotFound(ti.dag_id, task_id, self.key)
+
+
+def _get_callable_name(f: Callable | str) -> str:
+ """Try to "describe" a callable by getting its name."""
+ if callable(f):
+ return f.__name__
+ # Parse the source to find whatever is behind "def". For safety, we don't
+ # want to evaluate the code in any meaningful way!
+ with contextlib.suppress(Exception):
+ kw, name, _ = f.lstrip().split(None, 2)
+ if kw == "def":
+ return name
+ return ""
+
+
+class _MapResult(Sequence):
+ def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
+ self.value = value
+ self.callables = callables
+
+ def __getitem__(self, index: Any) -> Any:
+ value = self.value[index]
+
+ # In the worker, we can access all actual callables. Call them.
+ callables = [f for f in self.callables if callable(f)]
+ if len(callables) == len(self.callables):
+ for f in callables:
+ value = f(value)
+ return value
+
+ # In the scheduler, we don't have access to the actual callables, nor do
+ # we want to run them since they are arbitrary code. This builds a string to
+ # represent the call chain in the UI or logs instead.
+ for v in self.callables:
+ value = f"{_get_callable_name(v)}({value})"
+ return value
+
+ def __len__(self) -> int:
+ return len(self.value)
+
+
+class MapXComArg(XComArg):
+ """An XCom reference with ``map()`` call(s) applied.
+
+ This is based on an XComArg, but also applies a series of "transforms" that
+ convert the pulled XCom value.
+
+ :meta private:
+ """
- def set_upstream(
- self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
- edge_modifier: Optional[EdgeModifier] = None,
- ):
- """Proxy to underlying operator set_upstream method. Required by TaskMixin."""
- self.operator.set_upstream(task_or_task_list, edge_modifier)
+ def __init__(self, arg: XComArg, callables: MapCallables) -> None:
+ for c in callables:
+ if getattr(c, "_airflow_is_task_decorator", False):
+ raise ValueError("map() argument must be a plain function, not a @task operator")
+ self.arg = arg
+ self.callables = callables
- def set_downstream(
- self,
- task_or_task_list: Union[DependencyMixin, Sequence[DependencyMixin]],
- edge_modifier: Optional[EdgeModifier] = None,
- ):
- """Proxy to underlying operator set_downstream method. Required by TaskMixin."""
- self.operator.set_downstream(task_or_task_list, edge_modifier)
+ def __repr__(self) -> str:
+ map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
+ return f"{self.arg!r}{map_calls}"
+
+ def _serialize(self) -> dict[str, Any]:
+ return {
+ "arg": serialize_xcom_arg(self.arg),
+ "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
+ }
+
+ @classmethod
+ def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
+ # We are deliberately NOT deserializing the callables. These are shown
+ # in the UI, and displaying a function object is useless.
+ return cls(deserialize_xcom_arg(data["arg"], dag), data["callables"])
+
+ def iter_references(self) -> Iterator[tuple[Operator, str]]:
+ yield from self.arg.iter_references()
+
+ def map(self, f: Callable[[Any], Any]) -> MapXComArg:
+ # Flatten arg.map(f1).map(f2) into one MapXComArg.
+ return MapXComArg(self.arg, [*self.callables, f])
+
+ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
+ return self.arg.get_task_map_length(run_id, session=session)
@provide_session
- def resolve(self, context: Context, session: "Session" = NEW_SESSION) -> Any:
- """
- Pull XCom value for the existing arg. This method is run during ``op.execute()``
- in respectable context.
- """
- result = context["ti"].xcom_pull(
- task_ids=self.operator.task_id, key=str(self.key), default=NOTSET, session=session
- )
- if result is NOTSET:
- raise AirflowException(
- f'XComArg result from {self.operator.task_id} at {context["ti"].dag_id} '
- f'with key="{self.key}" is not found!'
- )
- return result
+ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ value = self.arg.resolve(context, session=session)
+ if not isinstance(value, (Sequence, dict)):
+ raise ValueError(f"XCom map expects sequence or dict, not {type(value).__name__}")
+ return _MapResult(value, self.callables)
- @staticmethod
- def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
- """Return XComArg instances in an arbitrary value.
- This recursively traverse ``arg`` and look for XComArg instances in any
- collection objects, and instances with ``template_fields`` set.
- """
- if isinstance(arg, XComArg):
- yield arg
- elif isinstance(arg, (tuple, set, list)):
- for elem in arg:
- yield from XComArg.iter_xcom_args(elem)
- elif isinstance(arg, dict):
- for elem in arg.values():
- yield from XComArg.iter_xcom_args(elem)
- elif isinstance(arg, AbstractOperator):
- for elem in arg.template_fields:
- yield from XComArg.iter_xcom_args(elem)
+class _ZipResult(Sequence):
+ def __init__(self, values: Sequence[Sequence | dict], *, fillvalue: Any = NOTSET) -> None:
+ self.values = values
+ self.fillvalue = fillvalue
@staticmethod
- def apply_upstream_relationship(op: "Operator", arg: Any):
- """Set dependency for XComArgs.
+ def _get_or_fill(container: Sequence | dict, index: Any, fillvalue: Any) -> Any:
+ try:
+ return container[index]
+ except (IndexError, KeyError):
+ return fillvalue
+
+ def __getitem__(self, index: Any) -> Any:
+ if index >= len(self):
+ raise IndexError(index)
+ return tuple(self._get_or_fill(value, index, self.fillvalue) for value in self.values)
+
+ def __len__(self) -> int:
+ lengths = (len(v) for v in self.values)
+ if isinstance(self.fillvalue, ArgNotSet):
+ return min(lengths)
+ return max(lengths)
+
+
+class ZipXComArg(XComArg):
+ """An XCom reference with ``zip()`` applied.
+
+ This is constructed from multiple XComArg instances, and presents an
+ iterable that "zips" them together like the built-in ``zip()`` (and
+ ``itertools.zip_longest()`` if ``fillvalue`` is provided).
+ """
- This looks for XComArg objects in ``arg`` "deeply" (looking inside
- collections objects and classes decorated with ``template_fields``), and
- sets the relationship to ``op`` on any found.
- """
- for ref in XComArg.iter_xcom_args(arg):
- op.set_upstream(ref.operator)
+ def __init__(self, args: Sequence[XComArg], *, fillvalue: Any = NOTSET) -> None:
+ if not args:
+ raise ValueError("At least one input is required")
+ self.args = args
+ self.fillvalue = fillvalue
+
+ def __repr__(self) -> str:
+ args_iter = iter(self.args)
+ first = repr(next(args_iter))
+ rest = ", ".join(repr(arg) for arg in args_iter)
+ if isinstance(self.fillvalue, ArgNotSet):
+ return f"{first}.zip({rest})"
+ return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})"
+
+ def _serialize(self) -> dict[str, Any]:
+ args = [serialize_xcom_arg(arg) for arg in self.args]
+ if isinstance(self.fillvalue, ArgNotSet):
+ return {"args": args}
+ return {"args": args, "fillvalue": self.fillvalue}
+
+ @classmethod
+ def _deserialize(cls, data: dict[str, Any], dag: DAG) -> XComArg:
+ return cls(
+ [deserialize_xcom_arg(arg, dag) for arg in data["args"]],
+ fillvalue=data.get("fillvalue", NOTSET),
+ )
+
+ def iter_references(self) -> Iterator[tuple[Operator, str]]:
+ for arg in self.args:
+ yield from arg.iter_references()
+
+ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None:
+ all_lengths = (arg.get_task_map_length(run_id, session=session) for arg in self.args)
+ ready_lengths = [length for length in all_lengths if length is not None]
+ if len(ready_lengths) != len(self.args):
+ return None # If any of the referenced XComs is not ready, we are not ready either.
+ if isinstance(self.fillvalue, ArgNotSet):
+ return min(ready_lengths)
+ return max(ready_lengths)
+
+ @provide_session
+ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
+ values = [arg.resolve(context, session=session) for arg in self.args]
+ for value in values:
+ if not isinstance(value, (Sequence, dict)):
+ raise ValueError(f"XCom zip expects sequence or dict, not {type(value).__name__}")
+ return _ZipResult(values, fillvalue=self.fillvalue)
+
+
+_XCOM_ARG_TYPES: Mapping[str, type[XComArg]] = {
+ "": PlainXComArg,
+ "map": MapXComArg,
+ "zip": ZipXComArg,
+}
+
+
+def serialize_xcom_arg(value: XComArg) -> dict[str, Any]:
+ """DAG serialization interface."""
+ key = next(k for k, v in _XCOM_ARG_TYPES.items() if v == type(value))
+ if key:
+ return {"type": key, **value._serialize()}
+ return value._serialize()
+
+
+def deserialize_xcom_arg(data: dict[str, Any], dag: DAG) -> XComArg:
+ """DAG serialization interface."""
+ klass = _XCOM_ARG_TYPES[data.get("type", "")]
+ return klass._deserialize(data, dag)
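Reviewer note on the length rule shared by ``_ZipResult.__len__`` and ``ZipXComArg.get_task_map_length`` in this file: without ``fillvalue`` the shortest input wins, and with it the longest input wins. Below is a standalone sketch of that rule with no Airflow dependency. ``ZipView`` and ``_MISSING`` are hypothetical stand-ins for ``_ZipResult`` and ``NOTSET``. The sketch is simplified: it handles only sequence inputs, not dicts, and is not the patch's implementation.

```python
from collections.abc import Sequence
from typing import Any

_MISSING = object()  # hypothetical stand-in for Airflow's NOTSET sentinel


class ZipView(Sequence):
    """Lazily zip several sequences, optionally padding the short ones."""

    def __init__(self, values: Sequence, *, fillvalue: Any = _MISSING) -> None:
        self.values = values
        self.fillvalue = fillvalue

    def __len__(self) -> int:
        lengths = [len(v) for v in self.values]
        # No fill value: behave like zip(); otherwise like zip_longest().
        return min(lengths) if self.fillvalue is _MISSING else max(lengths)

    def __getitem__(self, index: int) -> tuple:
        if not 0 <= index < len(self):
            raise IndexError(index)
        # Inputs shorter than the requested index contribute the fill value.
        return tuple(v[index] if index < len(v) else self.fillvalue for v in self.values)


short = ZipView([[1, 2, 3], ["a", "b"]])
padded = ZipView([[1, 2, 3], ["a", "b"]], fillvalue=None)
print(len(short), list(short))    # 2 [(1, 'a'), (2, 'b')]
print(len(padded), list(padded))  # 3 [(1, 'a'), (2, 'b'), (3, None)]
```

Iteration comes for free from the ``Sequence`` mixin, which calls ``__getitem__`` with increasing indexes until ``IndexError`` is raised.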
diff --git a/airflow/operators/__init__.py b/airflow/operators/__init__.py
index ec83c1e147a32..f4acbe0ebf883 100644
--- a/airflow/operators/__init__.py
+++ b/airflow/operators/__init__.py
@@ -15,4 +15,180 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+# fmt: off
"""Operators."""
+from __future__ import annotations
+
+from airflow.utils.deprecation_tools import add_deprecated_classes
+
+__deprecated_classes = {
+ 'bash_operator': {
+ 'BashOperator': 'airflow.operators.bash.BashOperator',
+ },
+ 'branch_operator': {
+ 'BaseBranchOperator': 'airflow.operators.branch.BaseBranchOperator',
+ },
+ 'check_operator': {
+ 'SQLCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
+ 'SQLIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
+ 'SQLThresholdCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator',
+ 'SQLValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
+ 'CheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
+ 'IntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
+ 'ThresholdCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator',
+ 'ValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
+ },
+ 'dagrun_operator': {
+ 'TriggerDagRunLink': 'airflow.operators.trigger_dagrun.TriggerDagRunLink',
+ 'TriggerDagRunOperator': 'airflow.operators.trigger_dagrun.TriggerDagRunOperator',
+ },
+ 'docker_operator': {
+ 'DockerOperator': 'airflow.providers.docker.operators.docker.DockerOperator',
+ },
+ 'druid_check_operator': {
+ 'DruidCheckOperator': 'airflow.providers.apache.druid.operators.druid_check.DruidCheckOperator',
+ },
+ 'dummy': {
+ 'EmptyOperator': 'airflow.operators.empty.EmptyOperator',
+ 'DummyOperator': 'airflow.operators.empty.EmptyOperator',
+ },
+ 'dummy_operator': {
+ 'EmptyOperator': 'airflow.operators.empty.EmptyOperator',
+ 'DummyOperator': 'airflow.operators.empty.EmptyOperator',
+ },
+ 'email_operator': {
+ 'EmailOperator': 'airflow.operators.email.EmailOperator',
+ },
+ 'gcs_to_s3': {
+ 'GCSToS3Operator': 'airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSToS3Operator',
+ },
+ 'google_api_to_s3_transfer': {
+ 'GoogleApiToS3Operator': (
+ 'airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleApiToS3Operator'
+ ),
+ 'GoogleApiToS3Transfer': (
+ 'airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleApiToS3Operator'
+ ),
+ },
+ 'hive_operator': {
+ 'HiveOperator': 'airflow.providers.apache.hive.operators.hive.HiveOperator',
+ },
+ 'hive_stats_operator': {
+ 'HiveStatsCollectionOperator': (
+ 'airflow.providers.apache.hive.operators.hive_stats.HiveStatsCollectionOperator'
+ ),
+ },
+ 'hive_to_druid': {
+ 'HiveToDruidOperator': 'airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator',
+ 'HiveToDruidTransfer': 'airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator',
+ },
+ 'hive_to_mysql': {
+ 'HiveToMySqlOperator': 'airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator',
+ 'HiveToMySqlTransfer': 'airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator',
+ },
+ 'hive_to_samba_operator': {
+ 'HiveToSambaOperator': 'airflow.providers.apache.hive.transfers.hive_to_samba.HiveToSambaOperator',
+ },
+ 'http_operator': {
+ 'SimpleHttpOperator': 'airflow.providers.http.operators.http.SimpleHttpOperator',
+ },
+ 'jdbc_operator': {
+ 'JdbcOperator': 'airflow.providers.jdbc.operators.jdbc.JdbcOperator',
+ },
+ 'latest_only_operator': {
+ 'LatestOnlyOperator': 'airflow.operators.latest_only.LatestOnlyOperator',
+ },
+ 'mssql_operator': {
+ 'MsSqlOperator': 'airflow.providers.microsoft.mssql.operators.mssql.MsSqlOperator',
+ },
+ 'mssql_to_hive': {
+ 'MsSqlToHiveOperator': 'airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator',
+ 'MsSqlToHiveTransfer': 'airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator',
+ },
+ 'mysql_operator': {
+ 'MySqlOperator': 'airflow.providers.mysql.operators.mysql.MySqlOperator',
+ },
+ 'mysql_to_hive': {
+ 'MySqlToHiveOperator': 'airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator',
+ 'MySqlToHiveTransfer': 'airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator',
+ },
+ 'oracle_operator': {
+ 'OracleOperator': 'airflow.providers.oracle.operators.oracle.OracleOperator',
+ },
+ 'papermill_operator': {
+ 'PapermillOperator': 'airflow.providers.papermill.operators.papermill.PapermillOperator',
+ },
+ 'pig_operator': {
+ 'PigOperator': 'airflow.providers.apache.pig.operators.pig.PigOperator',
+ },
+ 'postgres_operator': {
+ 'Mapping': 'airflow.providers.postgres.operators.postgres.Mapping',
+ 'PostgresOperator': 'airflow.providers.postgres.operators.postgres.PostgresOperator',
+ },
+ 'presto_check_operator': {
+ 'SQLCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
+ 'SQLIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
+ 'SQLValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
+ 'PrestoCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
+ 'PrestoIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
+ 'PrestoValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
+ },
+ 'presto_to_mysql': {
+ 'PrestoToMySqlOperator': 'airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator',
+ 'PrestoToMySqlTransfer': 'airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator',
+ },
+ 'python_operator': {
+ 'BranchPythonOperator': 'airflow.operators.python.BranchPythonOperator',
+ 'PythonOperator': 'airflow.operators.python.PythonOperator',
+ 'PythonVirtualenvOperator': 'airflow.operators.python.PythonVirtualenvOperator',
+ 'ShortCircuitOperator': 'airflow.operators.python.ShortCircuitOperator',
+ },
+ 'redshift_to_s3_operator': {
+ 'RedshiftToS3Operator': 'airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator',
+ 'RedshiftToS3Transfer': 'airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator',
+ },
+ 's3_file_transform_operator': {
+ 'S3FileTransformOperator': (
+ 'airflow.providers.amazon.aws.operators.s3_file_transform.S3FileTransformOperator'
+ ),
+ },
+ 's3_to_hive_operator': {
+ 'S3ToHiveOperator': 'airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator',
+ 'S3ToHiveTransfer': 'airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator',
+ },
+ 's3_to_redshift_operator': {
+ 'S3ToRedshiftOperator': 'airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator',
+ 'S3ToRedshiftTransfer': 'airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator',
+ },
+ 'slack_operator': {
+ 'SlackAPIOperator': 'airflow.providers.slack.operators.slack.SlackAPIOperator',
+ 'SlackAPIPostOperator': 'airflow.providers.slack.operators.slack.SlackAPIPostOperator',
+ },
+ 'sql': {
+ 'BaseSQLOperator': 'airflow.providers.common.sql.operators.sql.BaseSQLOperator',
+ 'BranchSQLOperator': 'airflow.providers.common.sql.operators.sql.BranchSQLOperator',
+ 'SQLCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLCheckOperator',
+ 'SQLColumnCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLColumnCheckOperator',
+ 'SQLIntervalCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLIntervalCheckOperator',
+ 'SQLTableCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLTableCheckOperator',
+ 'SQLThresholdCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLThresholdCheckOperator',
+ 'SQLValueCheckOperator': 'airflow.providers.common.sql.operators.sql.SQLValueCheckOperator',
+ '_convert_to_float_if_possible': (
+ 'airflow.providers.common.sql.operators.sql._convert_to_float_if_possible'
+ ),
+ 'parse_boolean': 'airflow.providers.common.sql.operators.sql.parse_boolean',
+ },
+ 'sql_branch_operator': {
+ 'BranchSQLOperator': 'airflow.providers.common.sql.operators.sql.BranchSQLOperator',
+ 'BranchSqlOperator': 'airflow.providers.common.sql.operators.sql.BranchSQLOperator',
+ },
+ 'sqlite_operator': {
+ 'SqliteOperator': 'airflow.providers.sqlite.operators.sqlite.SqliteOperator',
+ },
+ 'subdag_operator': {
+ 'SkippedStatePropagationOptions': 'airflow.operators.subdag.SkippedStatePropagationOptions',
+ 'SubDagOperator': 'airflow.operators.subdag.SubDagOperator',
+ },
+}
+
+add_deprecated_classes(__deprecated_classes, __name__)
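Review note: the mapping above replaces the per-module shim files deleted further down in this diff. The body of `add_deprecated_classes` is not shown here, so the following is only a minimal sketch of how such a lazy redirect could work, assuming a module-level `__getattr__` (PEP 562). The names `make_lazy_getattr` and `legacy_ordered_dict` are hypothetical and are not part of Airflow.

```python
import importlib
import sys
import types
import warnings


def make_lazy_getattr(module_name, deprecated):
    """Build a module-level __getattr__ that redirects old names to new paths.

    `deprecated` maps an old attribute name to the dotted path of its
    replacement, in the same spirit as the inner dicts of the mapping above.
    """

    def __getattr__(name):
        new_path = deprecated.get(name)
        if new_path is None:
            raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
        warnings.warn(
            f"{module_name}.{name} is deprecated; use {new_path} instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Import the replacement only when the old name is actually accessed.
        target_module, _, attr = new_path.rpartition(".")
        return getattr(importlib.import_module(target_module), attr)

    return __getattr__


# Hypothetical stand-in for an old module such as `bash_operator`.
legacy = types.ModuleType("legacy_ordered_dict")
legacy.__getattr__ = make_lazy_getattr(
    legacy.__name__, {"OrderedDict": "collections.OrderedDict"}
)
sys.modules[legacy.__name__] = legacy

# Accessing the old name resolves the new class and emits one warning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    resolved = legacy.OrderedDict
```

With this shape the warning is raised only when a deprecated name is actually used, rather than on import of the old module as the deleted shim files did.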
diff --git a/airflow/operators/bash.py b/airflow/operators/bash.py
index a4fac2b7f7eab..d50733bafa33f 100644
--- a/airflow/operators/bash.py
+++ b/airflow/operators/bash.py
@@ -15,8 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import os
-from typing import Dict, Optional, Sequence
+import shutil
+from typing import Sequence
from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException, AirflowSkipException
@@ -122,23 +125,23 @@ class BashOperator(BaseOperator):
"""
- template_fields: Sequence[str] = ('bash_command', 'env')
- template_fields_renderers = {'bash_command': 'bash', 'env': 'json'}
+ template_fields: Sequence[str] = ("bash_command", "env")
+ template_fields_renderers = {"bash_command": "bash", "env": "json"}
template_ext: Sequence[str] = (
- '.sh',
- '.bash',
+ ".sh",
+ ".bash",
)
- ui_color = '#f0ede4'
+ ui_color = "#f0ede4"
def __init__(
self,
*,
bash_command: str,
- env: Optional[Dict[str, str]] = None,
+ env: dict[str, str] | None = None,
append_env: bool = False,
- output_encoding: str = 'utf-8',
+ output_encoding: str = "utf-8",
skip_exit_code: int = 99,
- cwd: Optional[str] = None,
+ cwd: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -148,8 +151,6 @@ def __init__(
self.skip_exit_code = skip_exit_code
self.cwd = cwd
self.append_env = append_env
- if kwargs.get('xcom_push') is not None:
- raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")
@cached_property
def subprocess_hook(self):
@@ -169,13 +170,14 @@ def get_env(self, context):
airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
self.log.debug(
- 'Exporting the following env vars:\n%s',
- '\n'.join(f"{k}={v}" for k, v in airflow_context_vars.items()),
+ "Exporting the following env vars:\n%s",
+ "\n".join(f"{k}={v}" for k, v in airflow_context_vars.items()),
)
env.update(airflow_context_vars)
return env
def execute(self, context: Context):
+ bash_path = shutil.which("bash") or "bash"
if self.cwd is not None:
if not os.path.exists(self.cwd):
raise AirflowException(f"Can not find the cwd: {self.cwd}")
@@ -183,7 +185,7 @@ def execute(self, context: Context):
raise AirflowException(f"The cwd {self.cwd} must be a directory")
env = self.get_env(context)
result = self.subprocess_hook.run_command(
- command=['bash', '-c', self.bash_command],
+ command=[bash_path, "-c", self.bash_command],
env=env,
output_encoding=self.output_encoding,
cwd=self.cwd,
@@ -192,7 +194,7 @@ def execute(self, context: Context):
raise AirflowSkipException(f"Bash command returned exit code {self.skip_exit_code}. Skipping.")
elif result.exit_code != 0:
raise AirflowException(
- f'Bash command failed. The command returned a non-zero exit code {result.exit_code}.'
+ f"Bash command failed. The command returned a non-zero exit code {result.exit_code}."
)
return result.output
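Review note: the `shutil.which("bash") or "bash"` line in `execute` resolves the interpreter to an absolute path when one is found on `PATH` and otherwise falls back to the bare name. A standalone sketch of that pattern follows. It uses `subprocess.run` directly instead of the operator's `subprocess_hook`, and the helper names `resolve_executable` and `run_shell_snippet` are assumptions made for illustration.

```python
import shutil
import subprocess


def resolve_executable(name):
    """Return the absolute path of `name` if it is on PATH, else `name` unchanged."""
    return shutil.which(name) or name


def run_shell_snippet(snippet, shell="bash"):
    """Run `snippet` via `<shell> -c` and return (exit_code, stdout)."""
    completed = subprocess.run(
        [resolve_executable(shell), "-c", snippet],
        capture_output=True,
        text=True,
        check=False,  # the caller inspects the exit code itself
    )
    return completed.returncode, completed.stdout
```

Falling back to the bare name keeps the previous behaviour when the lookup fails, so the eventual error still comes from the process launch rather than from the lookup.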
diff --git a/airflow/operators/bash_operator.py b/airflow/operators/bash_operator.py
deleted file mode 100644
index 3b7764dfbd316..0000000000000
--- a/airflow/operators/bash_operator.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.bash`."""
-
-import warnings
-
-from airflow.operators.bash import BashOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.bash`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/operators/branch.py b/airflow/operators/branch.py
index cdd546feca469..8dfbe7c2767dc 100644
--- a/airflow/operators/branch.py
+++ b/airflow/operators/branch.py
@@ -16,10 +16,11 @@
# specific language governing permissions and limitations
# under the License.
"""Branching operators"""
+from __future__ import annotations
-from typing import Iterable, Union
+from typing import Iterable
-from airflow.models import BaseOperator
+from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.utils.context import Context
@@ -38,7 +39,7 @@ class BaseBranchOperator(BaseOperator, SkipMixin):
tasks directly downstream of this operator will be skipped.
"""
- def choose_branch(self, context: Context) -> Union[str, Iterable[str]]:
+ def choose_branch(self, context: Context) -> str | Iterable[str]:
"""
Subclasses should implement this, running whatever logic is
necessary to choose a branch and returning a task_id or list of
@@ -50,5 +51,5 @@ def choose_branch(self, context: Context) -> Union[str, Iterable[str]]:
def execute(self, context: Context):
branches_to_execute = self.choose_branch(context)
- self.skip_all_except(context['ti'], branches_to_execute)
+ self.skip_all_except(context["ti"], branches_to_execute)
return branches_to_execute
diff --git a/airflow/operators/branch_operator.py b/airflow/operators/branch_operator.py
deleted file mode 100644
index b4c71d5bc1f88..0000000000000
--- a/airflow/operators/branch_operator.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.branch`."""
-
-import warnings
-
-from airflow.operators.branch import BaseBranchOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.branch`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/operators/check_operator.py b/airflow/operators/check_operator.py
deleted file mode 100644
index 130eb6e577c9c..0000000000000
--- a/airflow/operators/check_operator.py
+++ /dev/null
@@ -1,96 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
-
-import warnings
-
-from airflow.operators.sql import (
- SQLCheckOperator,
- SQLIntervalCheckOperator,
- SQLThresholdCheckOperator,
- SQLValueCheckOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
-)
-
-
-class CheckOperator(SQLCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
-
-
-class IntervalCheckOperator(SQLIntervalCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
-
-
-class ThresholdCheckOperator(SQLThresholdCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLThresholdCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.operators.sql.SQLThresholdCheckOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
-
-
-class ValueCheckOperator(SQLValueCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/dagrun_operator.py b/airflow/operators/dagrun_operator.py
deleted file mode 100644
index bdcc6671516af..0000000000000
--- a/airflow/operators/dagrun_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.trigger_dagrun`."""
-
-import warnings
-
-from airflow.operators.trigger_dagrun import TriggerDagRunLink, TriggerDagRunOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.trigger_dagrun`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py
index c5a423d563868..335d33707684c 100644
--- a/airflow/operators/datetime.py
+++ b/airflow/operators/datetime.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import datetime
import warnings
-from typing import Iterable, Union
+from typing import Iterable
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning
from airflow.operators.branch import BaseBranchOperator
from airflow.utils import timezone
from airflow.utils.context import Context
@@ -47,10 +48,10 @@ class BranchDateTimeOperator(BaseBranchOperator):
def __init__(
self,
*,
- follow_task_ids_if_true: Union[str, Iterable[str]],
- follow_task_ids_if_false: Union[str, Iterable[str]],
- target_lower: Union[datetime.datetime, datetime.time, None],
- target_upper: Union[datetime.datetime, datetime.time, None],
+ follow_task_ids_if_true: str | Iterable[str],
+ follow_task_ids_if_false: str | Iterable[str],
+ target_lower: datetime.datetime | datetime.time | None,
+ target_upper: datetime.datetime | datetime.time | None,
use_task_logical_date: bool = False,
use_task_execution_date: bool = False,
**kwargs,
@@ -71,17 +72,19 @@ def __init__(
self.use_task_logical_date = use_task_execution_date
warnings.warn(
"Parameter ``use_task_execution_date`` is deprecated. Use ``use_task_logical_date``.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
- def choose_branch(self, context: Context) -> Union[str, Iterable[str]]:
+ def choose_branch(self, context: Context) -> str | Iterable[str]:
if self.use_task_logical_date:
- now = timezone.make_naive(context["logical_date"], self.dag.timezone)
+ now = context["logical_date"]
else:
- now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)
-
+ now = timezone.coerce_datetime(timezone.utcnow())
lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper)
+ lower = timezone.coerce_datetime(lower, self.dag.timezone)
+ upper = timezone.coerce_datetime(upper, self.dag.timezone)
+
if upper is not None and upper < now:
return self.follow_task_ids_if_false
@@ -93,8 +96,8 @@ def choose_branch(self, context: Context) -> Union[str, Iterable[str]]:
def target_times_as_dates(
base_date: datetime.datetime,
- lower: Union[datetime.datetime, datetime.time, None],
- upper: Union[datetime.datetime, datetime.time, None],
+ lower: datetime.datetime | datetime.time | None,
+ upper: datetime.datetime | datetime.time | None,
):
"""Ensures upper and lower time targets are datetimes by combining them with base_date"""
if isinstance(lower, datetime.datetime) and isinstance(upper, datetime.datetime):
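Review note: `choose_branch` now compares timezone-aware values instead of converting `now` to a naive datetime first. The sketch below shows the general idea with the standard library only. It assumes `now` and any datetime bounds are all timezone-aware, it does not use Airflow's `timezone.coerce_datetime`, and the helper names `window_for` and `in_window` are hypothetical.

```python
import datetime


def window_for(now, lower, upper):
    """Turn bare times into datetimes on now's date, in now's timezone.

    If the upper bound ends up before the lower bound, it is rolled over
    to the next day so that windows spanning midnight still work.
    """
    tz = now.tzinfo
    if isinstance(lower, datetime.time):
        lower = datetime.datetime.combine(now.date(), lower, tzinfo=tz)
    if isinstance(upper, datetime.time):
        upper = datetime.datetime.combine(now.date(), upper, tzinfo=tz)
    if lower is not None and upper is not None and upper < lower:
        upper += datetime.timedelta(days=1)
    return lower, upper


def in_window(now, lower, upper):
    """Return True when `now` is inside the (possibly open-ended) window."""
    lower, upper = window_for(now, lower, upper)
    if upper is not None and upper < now:
        return False
    if lower is not None and lower > now:
        return False
    return True
```

Mixing naive and aware datetimes in these comparisons raises `TypeError`, which is one reason to normalise both bounds before comparing them.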
diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py
deleted file mode 100644
index 88235b4382461..0000000000000
--- a/airflow/operators/docker_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.docker.operators.docker`."""
-
-import warnings
-
-from airflow.providers.docker.operators.docker import DockerOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.docker.operators.docker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/druid_check_operator.py b/airflow/operators/druid_check_operator.py
deleted file mode 100644
index 008a91750c91d..0000000000000
--- a/airflow/operators/druid_check_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.druid.operators.druid_check`."""
-
-import warnings
-
-from airflow.providers.apache.druid.operators.druid_check import DruidCheckOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql.SQLCheckOperator`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/dummy.py b/airflow/operators/dummy.py
deleted file mode 100644
index b2912e92586f0..0000000000000
--- a/airflow/operators/dummy.py
+++ /dev/null
@@ -1,45 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.empty`."""
-
-import warnings
-
-from airflow.operators.empty import EmptyOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.empty`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class DummyOperator(EmptyOperator):
- """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`."""
-
- @property
- def inherits_from_dummy_operator(self):
- return True
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- self.inherits_from_empty_operator = self.inherits_from_dummy_operator
- super().__init__(**kwargs)
diff --git a/airflow/operators/dummy_operator.py b/airflow/operators/dummy_operator.py
deleted file mode 100644
index 2b46095027875..0000000000000
--- a/airflow/operators/dummy_operator.py
+++ /dev/null
@@ -1,38 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.empty`."""
-
-import warnings
-
-from airflow.operators.empty import EmptyOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.empty`.", DeprecationWarning, stacklevel=2
-)
-
-
-class DummyOperator(EmptyOperator):
- """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`."""
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- """This class is deprecated. Please use `airflow.operators.empty.EmptyOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/operators/email.py b/airflow/operators/email.py
index 220bafa944b12..1b8e529593451 100644
--- a/airflow/operators/email.py
+++ b/airflow/operators/email.py
@@ -15,9 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Sequence, Union
+from __future__ import annotations
-from airflow.models import BaseOperator
+from typing import Any, Sequence
+
+from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
from airflow.utils.email import send_email
@@ -39,24 +41,24 @@ class EmailOperator(BaseOperator):
:param custom_headers: additional headers to add to the MIME message.
"""
- template_fields: Sequence[str] = ('to', 'subject', 'html_content', 'files')
+ template_fields: Sequence[str] = ("to", "subject", "html_content", "files")
template_fields_renderers = {"html_content": "html"}
- template_ext: Sequence[str] = ('.html',)
- ui_color = '#e6faf9'
+ template_ext: Sequence[str] = (".html",)
+ ui_color = "#e6faf9"
def __init__(
self,
*,
- to: Union[List[str], str],
+ to: list[str] | str,
subject: str,
html_content: str,
- files: Optional[List] = None,
- cc: Optional[Union[List[str], str]] = None,
- bcc: Optional[Union[List[str], str]] = None,
- mime_subtype: str = 'mixed',
- mime_charset: str = 'utf-8',
- conn_id: Optional[str] = None,
- custom_headers: Optional[Dict[str, Any]] = None,
+ files: list | None = None,
+ cc: list[str] | str | None = None,
+ bcc: list[str] | str | None = None,
+ mime_subtype: str = "mixed",
+ mime_charset: str = "utf-8",
+ conn_id: str | None = None,
+ custom_headers: dict[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
diff --git a/airflow/operators/email_operator.py b/airflow/operators/email_operator.py
deleted file mode 100644
index 80901d010f669..0000000000000
--- a/airflow/operators/email_operator.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.email`."""
-
-import warnings
-
-from airflow.operators.email import EmailOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.email`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/operators/empty.py b/airflow/operators/empty.py
index 1f396d0f6a664..d4438c565e927 100644
--- a/airflow/operators/empty.py
+++ b/airflow/operators/empty.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
@@ -27,7 +28,7 @@ class EmptyOperator(BaseOperator):
The task is evaluated by the scheduler but never processed by the executor.
"""
- ui_color = '#e8f7e4'
+ ui_color = "#e8f7e4"
inherits_from_empty_operator = True
def execute(self, context: Context):
diff --git a/airflow/operators/gcs_to_s3.py b/airflow/operators/gcs_to_s3.py
deleted file mode 100644
index d02bc7f224ea9..0000000000000
--- a/airflow/operators/gcs_to_s3.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.transfers.gcs_to_s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.gcs_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/generic_transfer.py b/airflow/operators/generic_transfer.py
index 2a42859d17e74..25cd3f1f69bdf 100644
--- a/airflow/operators/generic_transfer.py
+++ b/airflow/operators/generic_transfer.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional, Sequence, Union
+from __future__ import annotations
+
+from typing import Sequence
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
@@ -40,13 +42,13 @@ class GenericTransfer(BaseOperator):
:param insert_args: extra params for `insert_rows` method.
"""
- template_fields: Sequence[str] = ('sql', 'destination_table', 'preoperator')
+ template_fields: Sequence[str] = ("sql", "destination_table", "preoperator")
template_ext: Sequence[str] = (
- '.sql',
- '.hql',
+ ".sql",
+ ".hql",
)
template_fields_renderers = {"preoperator": "sql"}
- ui_color = '#b0f07c'
+ ui_color = "#b0f07c"
def __init__(
self,
@@ -55,8 +57,8 @@ def __init__(
destination_table: str,
source_conn_id: str,
destination_conn_id: str,
- preoperator: Optional[Union[str, List[str]]] = None,
- insert_args: Optional[dict] = None,
+ preoperator: str | list[str] | None = None,
+ insert_args: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -73,7 +75,7 @@ def execute(self, context: Context):
self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)
- get_records = getattr(source_hook, 'get_records', None)
+ get_records = getattr(source_hook, "get_records", None)
if not callable(get_records):
raise RuntimeError(
f"Hook for connection {self.source_conn_id!r} "
@@ -83,7 +85,7 @@ def execute(self, context: Context):
results = get_records(self.sql)
if self.preoperator:
- run = getattr(destination_hook, 'run', None)
+ run = getattr(destination_hook, "run", None)
if not callable(run):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
@@ -93,7 +95,7 @@ def execute(self, context: Context):
self.log.info(self.preoperator)
run(self.preoperator)
- insert_rows = getattr(destination_hook, 'insert_rows', None)
+ insert_rows = getattr(destination_hook, "insert_rows", None)
if not callable(insert_rows):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
diff --git a/airflow/operators/google_api_to_s3_transfer.py b/airflow/operators/google_api_to_s3_transfer.py
deleted file mode 100644
index 9566cddc77641..0000000000000
--- a/airflow/operators/google_api_to_s3_transfer.py
+++ /dev/null
@@ -1,50 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.google_api_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class GoogleApiToS3Transfer(GoogleApiToS3Operator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.amazon.aws.transfers.google_api_to_s3.GoogleApiToS3Operator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use "
- "`airflow.providers.amazon.aws.transfers."
- "google_api_to_s3_transfer.GoogleApiToS3Operator`.",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/hive_operator.py b/airflow/operators/hive_operator.py
deleted file mode 100644
index b49cf097305ea..0000000000000
--- a/airflow/operators/hive_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.operators.hive`."""
-
-import warnings
-
-from airflow.providers.apache.hive.operators.hive import HiveOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/hive_stats_operator.py b/airflow/operators/hive_stats_operator.py
deleted file mode 100644
index af1e260a4a155..0000000000000
--- a/airflow/operators/hive_stats_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.operators.hive_stats`."""
-
-import warnings
-
-from airflow.providers.apache.hive.operators.hive_stats import HiveStatsCollectionOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.operators.hive_stats`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/hive_to_druid.py b/airflow/operators/hive_to_druid.py
deleted file mode 100644
index a6537a1337a56..0000000000000
--- a/airflow/operators/hive_to_druid.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.apache.druid.transfers.hive_to_druid`.
-"""
-
-import warnings
-
-from airflow.providers.apache.druid.transfers.hive_to_druid import HiveToDruidOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.druid.transfers.hive_to_druid`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class HiveToDruidTransfer(HiveToDruidOperator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.apache.druid.transfers.hive_to_druid.HiveToDruidOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/hive_to_mysql.py b/airflow/operators/hive_to_mysql.py
deleted file mode 100644
index 0a13c7666a4cb..0000000000000
--- a/airflow/operators/hive_to_mysql.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use `airflow.providers.apache.hive.transfers.hive_to_mysql`.
-"""
-
-import warnings
-
-from airflow.providers.apache.hive.transfers.hive_to_mysql import HiveToMySqlOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_mysql`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class HiveToMySqlTransfer(HiveToMySqlOperator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.apache.hive.transfers.hive_to_mysql.HiveToMySqlOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/hive_to_samba_operator.py b/airflow/operators/hive_to_samba_operator.py
deleted file mode 100644
index ed3b180b3e7c4..0000000000000
--- a/airflow/operators/hive_to_samba_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.transfers.hive_to_samba`."""
-
-import warnings
-
-from airflow.providers.apache.hive.transfers.hive_to_samba import HiveToSambaOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.hive_to_samba`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/http_operator.py b/airflow/operators/http_operator.py
deleted file mode 100644
index 6e2ab56df4e58..0000000000000
--- a/airflow/operators/http_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.http.operators.http`."""
-
-import warnings
-
-from airflow.providers.http.operators.http import SimpleHttpOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.http.operators.http`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/jdbc_operator.py b/airflow/operators/jdbc_operator.py
deleted file mode 100644
index ff36f9f5d6467..0000000000000
--- a/airflow/operators/jdbc_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.jdbc.operators.jdbc`."""
-
-import warnings
-
-from airflow.providers.jdbc.operators.jdbc import JdbcOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.jdbc.operators.jdbc`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/latest_only.py b/airflow/operators/latest_only.py
index 8ca9688f96627..f0ba76ecf6b67 100644
--- a/airflow/operators/latest_only.py
+++ b/airflow/operators/latest_only.py
@@ -19,7 +19,9 @@
This module contains an operator to run downstream tasks only for the
latest scheduled DagRun
"""
-from typing import TYPE_CHECKING, Iterable, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable
import pendulum
@@ -42,37 +44,37 @@ class LatestOnlyOperator(BaseBranchOperator):
marked as externally triggered.
"""
- ui_color = '#e9ffdb' # nyanza
+ ui_color = "#e9ffdb" # nyanza
- def choose_branch(self, context: Context) -> Union[str, Iterable[str]]:
+ def choose_branch(self, context: Context) -> str | Iterable[str]:
# If the DAG Run is externally triggered, then return without
# skipping downstream tasks
- dag_run: "DagRun" = context["dag_run"]
+ dag_run: DagRun = context["dag_run"]
if dag_run.external_trigger:
self.log.info("Externally triggered DAG_Run: allowing execution to proceed.")
- return list(context['task'].get_direct_relative_ids(upstream=False))
+ return list(context["task"].get_direct_relative_ids(upstream=False))
- dag: "DAG" = context["dag"]
+ dag: DAG = context["dag"]
next_info = dag.next_dagrun_info(dag.get_run_data_interval(dag_run), restricted=False)
- now = pendulum.now('UTC')
+ now = pendulum.now("UTC")
if next_info is None:
self.log.info("Last scheduled execution: allowing execution to proceed.")
- return list(context['task'].get_direct_relative_ids(upstream=False))
+ return list(context["task"].get_direct_relative_ids(upstream=False))
left_window, right_window = next_info.data_interval
self.log.info(
- 'Checking latest only with left_window: %s right_window: %s now: %s',
+ "Checking latest only with left_window: %s right_window: %s now: %s",
left_window,
right_window,
now,
)
if not left_window < now <= right_window:
- self.log.info('Not latest execution, skipping downstream.')
+ self.log.info("Not latest execution, skipping downstream.")
# we return an empty list, thus the parent BaseBranchOperator
# won't exclude any downstream tasks from skipping.
return []
else:
- self.log.info('Latest, allowing execution to proceed.')
- return list(context['task'].get_direct_relative_ids(upstream=False))
+ self.log.info("Latest, allowing execution to proceed.")
+ return list(context["task"].get_direct_relative_ids(upstream=False))
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
deleted file mode 100644
index 07644f4a82c10..0000000000000
--- a/airflow/operators/latest_only_operator.py
+++ /dev/null
@@ -1,25 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.latest_only`"""
-import warnings
-
-from airflow.operators.latest_only import LatestOnlyOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.latest_only`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/operators/mssql_operator.py b/airflow/operators/mssql_operator.py
deleted file mode 100644
index d1047b827a722..0000000000000
--- a/airflow/operators/mssql_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.microsoft.mssql.operators.mssql`."""
-
-import warnings
-
-from airflow.providers.microsoft.mssql.operators.mssql import MsSqlOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.microsoft.mssql.operators.mssql`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/mssql_to_hive.py b/airflow/operators/mssql_to_hive.py
deleted file mode 100644
index 02edb36b89b15..0000000000000
--- a/airflow/operators/mssql_to_hive.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.apache.hive.transfers.mssql_to_hive`.
-"""
-
-import warnings
-
-from airflow.providers.apache.hive.transfers.mssql_to_hive import MsSqlToHiveOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mssql_to_hive`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MsSqlToHiveTransfer(MsSqlToHiveOperator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.apache.hive.transfers.mssql_to_hive.MsSqlToHiveOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/mysql_operator.py b/airflow/operators/mysql_operator.py
deleted file mode 100644
index 82a94edd66add..0000000000000
--- a/airflow/operators/mysql_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.mysql.operators.mysql`."""
-
-import warnings
-
-from airflow.providers.mysql.operators.mysql import MySqlOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.mysql.operators.mysql`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/mysql_to_hive.py b/airflow/operators/mysql_to_hive.py
deleted file mode 100644
index 95bd4302e7961..0000000000000
--- a/airflow/operators/mysql_to_hive.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.transfers.mysql_to_hive`."""
-
-import warnings
-
-from airflow.providers.apache.hive.transfers.mysql_to_hive import MySqlToHiveOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.mysql_to_hive`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MySqlToHiveTransfer(MySqlToHiveOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.apache.hive.transfers.mysql_to_hive.MySqlToHiveOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/oracle_operator.py b/airflow/operators/oracle_operator.py
deleted file mode 100644
index 8ad61db754dcb..0000000000000
--- a/airflow/operators/oracle_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.oracle.operators.oracle`."""
-
-import warnings
-
-from airflow.providers.oracle.operators.oracle import OracleOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.oracle.operators.oracle`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/papermill_operator.py b/airflow/operators/papermill_operator.py
deleted file mode 100644
index 5d63e38e13721..0000000000000
--- a/airflow/operators/papermill_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.papermill.operators.papermill`."""
-
-import warnings
-
-from airflow.providers.papermill.operators.papermill import PapermillOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.papermill.operators.papermill`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/pig_operator.py b/airflow/operators/pig_operator.py
deleted file mode 100644
index 3b2ea0e05ac99..0000000000000
--- a/airflow/operators/pig_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.pig.operators.pig`."""
-
-import warnings
-
-from airflow.providers.apache.pig.operators.pig import PigOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.pig.operators.pig`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/postgres_operator.py b/airflow/operators/postgres_operator.py
deleted file mode 100644
index e5dc53c82bde6..0000000000000
--- a/airflow/operators/postgres_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.postgres.operators.postgres`."""
-
-import warnings
-
-from airflow.providers.postgres.operators.postgres import Mapping, PostgresOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.postgres.operators.postgres`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/presto_check_operator.py b/airflow/operators/presto_check_operator.py
deleted file mode 100644
index 693471f18ceb4..0000000000000
--- a/airflow/operators/presto_check_operator.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
-
-import warnings
-
-from airflow.operators.sql import SQLCheckOperator, SQLIntervalCheckOperator, SQLValueCheckOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.sql`.", DeprecationWarning, stacklevel=2
-)
-
-
-class PrestoCheckOperator(SQLCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
-
-
-class PrestoIntervalCheckOperator(SQLIntervalCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """
- This class is deprecated.l
- Please use `airflow.operators.sql.SQLIntervalCheckOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
-
-
-class PrestoValueCheckOperator(SQLValueCheckOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.SQLValueCheckOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """
- This class is deprecated.l
- Please use `airflow.operators.sql.SQLValueCheckOperator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/presto_to_mysql.py b/airflow/operators/presto_to_mysql.py
deleted file mode 100644
index bfc117327d672..0000000000000
--- a/airflow/operators/presto_to_mysql.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.mysql.transfers.presto_to_mysql`.
-"""
-
-import warnings
-
-from airflow.providers.mysql.transfers.presto_to_mysql import PrestoToMySqlOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.mysql.transfers.presto_to_mysql`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class PrestoToMySqlTransfer(PrestoToMySqlOperator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.mysql.transfers.presto_to_mysql.PrestoToMySqlOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 920a75d08f766..0a0dd34fef355 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -15,20 +15,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import inspect
import os
import pickle
import shutil
+import subprocess
import sys
import types
import warnings
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
from tempfile import TemporaryDirectory
from textwrap import dedent
-from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Union
+from typing import Any, Callable, Collection, Iterable, Mapping, Sequence
import dill
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowConfigException, AirflowException, RemovedInAirflow3Warning
from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
@@ -36,9 +41,10 @@
from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
+from airflow.version import version as airflow_version
-def task(python_callable: Optional[Callable] = None, multiple_outputs: Optional[bool] = None, **kwargs):
+def task(python_callable: Callable | None = None, multiple_outputs: bool | None = None, **kwargs):
"""
Deprecated function that calls @task.python and allows users to turn a python function into
an Airflow task. Please use the following instead:
@@ -68,7 +74,7 @@ def my_task()
from airflow.decorators import task
@task
def my_task()""",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
return python_task(python_callable=python_callable, multiple_outputs=multiple_outputs, **kwargs)
@@ -121,41 +127,39 @@ def my_python_callable(**kwargs):
such as transmission a large amount of XCom to TaskAPI.
"""
- template_fields: Sequence[str] = ('templates_dict', 'op_args', 'op_kwargs')
+ template_fields: Sequence[str] = ("templates_dict", "op_args", "op_kwargs")
template_fields_renderers = {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"}
- BLUE = '#ffefeb'
+ BLUE = "#ffefeb"
ui_color = BLUE
# since we won't mutate the arguments, we should just do the shallow copy
# there are some cases we can't deepcopy the objects(e.g protobuf).
shallow_copy_attrs: Sequence[str] = (
- 'python_callable',
- 'op_kwargs',
+ "python_callable",
+ "op_kwargs",
)
- mapped_arguments_validated_by_init = True
-
def __init__(
self,
*,
python_callable: Callable,
- op_args: Optional[Collection[Any]] = None,
- op_kwargs: Optional[Mapping[str, Any]] = None,
- templates_dict: Optional[Dict[str, Any]] = None,
- templates_exts: Optional[Sequence[str]] = None,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ templates_dict: dict[str, Any] | None = None,
+ templates_exts: Sequence[str] | None = None,
show_return_value_in_logs: bool = True,
**kwargs,
) -> None:
if kwargs.get("provide_context"):
warnings.warn(
"provide_context is deprecated as of 2.0 and is no longer required",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
- kwargs.pop('provide_context', None)
+ kwargs.pop("provide_context", None)
super().__init__(**kwargs)
if not callable(python_callable):
- raise AirflowException('`python_callable` param must be callable')
+ raise AirflowException("`python_callable` param must be callable")
self.python_callable = python_callable
self.op_args = op_args or ()
self.op_kwargs = op_kwargs or {}
@@ -179,12 +183,11 @@ def execute(self, context: Context) -> Any:
def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()
- def execute_callable(self):
+ def execute_callable(self) -> Any:
"""
Calls the python callable with the given arguments.
:return: the return value of the call.
- :rtype: any
"""
return self.python_callable(*self.op_args, **self.op_kwargs)
@@ -205,22 +208,8 @@ class BranchPythonOperator(PythonOperator, SkipMixin):
def execute(self, context: Context) -> Any:
branch = super().execute(context)
- # TODO: The logic should be moved to SkipMixin to be available to all branch operators.
- if isinstance(branch, str):
- branches = {branch}
- elif isinstance(branch, list):
- branches = set(branch)
- elif branch is None:
- branches = set()
- else:
- raise AirflowException("Branch callable must return either None, a task ID, or a list of IDs")
- valid_task_ids = set(context["dag"].task_ids)
- invalid_task_ids = branches - valid_task_ids
- if invalid_task_ids:
- raise AirflowException(
- f"Branch callable must return valid task_ids. Invalid tasks found: {invalid_task_ids}"
- )
- self.skip_all_except(context['ti'], branch)
+ self.log.info("Branch callable return %s", branch)
+ self.skip_all_except(context["ti"], branch)
return branch
@@ -259,10 +248,10 @@ def execute(self, context: Context) -> Any:
self.log.info("Condition result is %s", condition)
if condition:
- self.log.info('Proceeding with downstream tasks...')
+ self.log.info("Proceeding with downstream tasks...")
return condition
- downstream_tasks = context['task'].get_flat_relatives(upstream=False)
+ downstream_tasks = context["task"].get_flat_relatives(upstream=False)
self.log.debug("Downstream task IDs %s", downstream_tasks)
if downstream_tasks:
@@ -281,7 +270,161 @@ def execute(self, context: Context) -> Any:
self.log.info("Done.")
-class PythonVirtualenvOperator(PythonOperator):
+class _BasePythonVirtualenvOperator(PythonOperator, metaclass=ABCMeta):
+ BASE_SERIALIZABLE_CONTEXT_KEYS = {
+ "ds",
+ "ds_nodash",
+ "expanded_ti_count",
+ "inlets",
+ "next_ds",
+ "next_ds_nodash",
+ "outlets",
+ "prev_ds",
+ "prev_ds_nodash",
+ "run_id",
+ "task_instance_key_str",
+ "test_mode",
+ "tomorrow_ds",
+ "tomorrow_ds_nodash",
+ "ts",
+ "ts_nodash",
+ "ts_nodash_with_tz",
+ "yesterday_ds",
+ "yesterday_ds_nodash",
+ }
+ PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
+ "data_interval_end",
+ "data_interval_start",
+ "execution_date",
+ "logical_date",
+ "next_execution_date",
+ "prev_data_interval_end_success",
+ "prev_data_interval_start_success",
+ "prev_execution_date",
+ "prev_execution_date_success",
+ "prev_start_date_success",
+ }
+ AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {
+ "macros",
+ "conf",
+ "dag",
+ "dag_run",
+ "task",
+ "params",
+ "triggering_dataset_events",
+ }
+
+ def __init__(
+ self,
+ *,
+ python_callable: Callable,
+ use_dill: bool = False,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ string_args: Iterable[str] | None = None,
+ templates_dict: dict | None = None,
+ templates_exts: list[str] | None = None,
+ expect_airflow: bool = True,
+ **kwargs,
+ ):
+ if (
+ not isinstance(python_callable, types.FunctionType)
+ or isinstance(python_callable, types.LambdaType)
+ and python_callable.__name__ == ""
+ ):
+ raise AirflowException("PythonVirtualenvOperator only supports functions for python_callable arg")
+ super().__init__(
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ templates_dict=templates_dict,
+ templates_exts=templates_exts,
+ **kwargs,
+ )
+ self.string_args = string_args or []
+ self.use_dill = use_dill
+ self.pickling_library = dill if self.use_dill else pickle
+ self.expect_airflow = expect_airflow
+
+ @abstractmethod
+ def _iter_serializable_context_keys(self):
+ pass
+
+ def execute(self, context: Context) -> Any:
+ serializable_keys = set(self._iter_serializable_context_keys())
+ serializable_context = context_copy_partial(context, serializable_keys)
+ return super().execute(context=serializable_context)
+
+ def get_python_source(self):
+ """
+ Returns the source of self.python_callable
+ @return:
+ """
+ return dedent(inspect.getsource(self.python_callable))
+
+ def _write_args(self, file: Path):
+ if self.op_args or self.op_kwargs:
+ file.write_bytes(self.pickling_library.dumps({"args": self.op_args, "kwargs": self.op_kwargs}))
+
+ def _write_string_args(self, file: Path):
+ file.write_text("\n".join(map(str, self.string_args)))
+
+ def _read_result(self, path: Path):
+ if path.stat().st_size == 0:
+ return None
+ try:
+ return self.pickling_library.loads(path.read_bytes())
+ except ValueError:
+ self.log.error(
+ "Error deserializing result. Note that result deserialization "
+ "is not supported across major Python versions."
+ )
+ raise
+
+ def __deepcopy__(self, memo):
+ # module objects can't be copied _at all__
+ memo[id(self.pickling_library)] = self.pickling_library
+ return super().__deepcopy__(memo)
+
+ def _execute_python_callable_in_subprocess(self, python_path: Path, tmp_dir: Path):
+ op_kwargs: dict[str, Any] = dict(self.op_kwargs)
+ if self.templates_dict:
+ op_kwargs["templates_dict"] = self.templates_dict
+ input_path = tmp_dir / "script.in"
+ output_path = tmp_dir / "script.out"
+ string_args_path = tmp_dir / "string_args.txt"
+ script_path = tmp_dir / "script.py"
+ self._write_args(input_path)
+ self._write_string_args(string_args_path)
+ write_python_script(
+ jinja_context=dict(
+ op_args=self.op_args,
+ op_kwargs=op_kwargs,
+ expect_airflow=self.expect_airflow,
+ pickling_library=self.pickling_library.__name__,
+ python_callable=self.python_callable.__name__,
+ python_callable_source=self.get_python_source(),
+ ),
+ filename=os.fspath(script_path),
+ render_template_as_native_obj=self.dag.render_template_as_native_obj,
+ )
+
+ execute_in_subprocess(
+ cmd=[
+ os.fspath(python_path),
+ os.fspath(script_path),
+ os.fspath(input_path),
+ os.fspath(output_path),
+ os.fspath(string_args_path),
+ ]
+ )
+ return self._read_result(output_path)
+
+ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
+ return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing()
+
+
+class PythonVirtualenvOperator(_BasePythonVirtualenvOperator):
"""
Allows one to run a function in a virtualenv that is created and destroyed
automatically (with certain caveats).
@@ -325,67 +468,31 @@ class PythonVirtualenvOperator(PythonOperator):
in your callable's context after the template has been applied
:param templates_exts: a list of file extensions to resolve while
processing templated fields, for examples ``['.sql', '.hql']``
+ :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
+ will log a warning if Airflow is not installed, and it will attempt to load Airflow
+ macros when starting.
"""
- template_fields: Sequence[str] = tuple({'requirements'} | set(PythonOperator.template_fields))
-
- template_ext: Sequence[str] = ('.txt',)
- BASE_SERIALIZABLE_CONTEXT_KEYS = {
- 'ds',
- 'ds_nodash',
- 'inlets',
- 'next_ds',
- 'next_ds_nodash',
- 'outlets',
- 'prev_ds',
- 'prev_ds_nodash',
- 'run_id',
- 'task_instance_key_str',
- 'test_mode',
- 'tomorrow_ds',
- 'tomorrow_ds_nodash',
- 'ts',
- 'ts_nodash',
- 'ts_nodash_with_tz',
- 'yesterday_ds',
- 'yesterday_ds_nodash',
- }
- PENDULUM_SERIALIZABLE_CONTEXT_KEYS = {
- 'data_interval_end',
- 'data_interval_start',
- 'execution_date',
- 'logical_date',
- 'next_execution_date',
- 'prev_data_interval_end_success',
- 'prev_data_interval_start_success',
- 'prev_execution_date',
- 'prev_execution_date_success',
- 'prev_start_date_success',
- }
- AIRFLOW_SERIALIZABLE_CONTEXT_KEYS = {'macros', 'conf', 'dag', 'dag_run', 'task', 'params'}
+ template_fields: Sequence[str] = tuple({"requirements"} | set(PythonOperator.template_fields))
+ template_ext: Sequence[str] = (".txt",)
def __init__(
self,
*,
python_callable: Callable,
- requirements: Union[None, Iterable[str], str] = None,
- python_version: Optional[Union[str, int, float]] = None,
+ requirements: None | Iterable[str] | str = None,
+ python_version: str | int | float | None = None,
use_dill: bool = False,
system_site_packages: bool = True,
- pip_install_options: Optional[List[str]] = None,
- op_args: Optional[Collection[Any]] = None,
- op_kwargs: Optional[Mapping[str, Any]] = None,
- string_args: Optional[Iterable[str]] = None,
- templates_dict: Optional[Dict] = None,
- templates_exts: Optional[List[str]] = None,
+ pip_install_options: list[str] | None = None,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ string_args: Iterable[str] | None = None,
+ templates_dict: dict | None = None,
+ templates_exts: list[str] | None = None,
+ expect_airflow: bool = True,
**kwargs,
):
- if (
- not isinstance(python_callable, types.FunctionType)
- or isinstance(python_callable, types.LambdaType)
- and python_callable.__name__ == ""
- ):
- raise AirflowException('PythonVirtualenvOperator only supports functions for python_callable arg')
if (
python_version
and str(python_version)[0] != str(sys.version_info.major)
@@ -394,41 +501,35 @@ def __init__(
raise AirflowException(
"Passing op_args or op_kwargs is not supported across different Python "
"major versions for PythonVirtualenvOperator. Please use string_args."
+ f" Sys version: {sys.version_info}. Venv version: {python_version}"
)
if not shutil.which("virtualenv"):
- raise AirflowException('PythonVirtualenvOperator requires virtualenv, please install it.')
- super().__init__(
- python_callable=python_callable,
- op_args=op_args,
- op_kwargs=op_kwargs,
- templates_dict=templates_dict,
- templates_exts=templates_exts,
- **kwargs,
- )
+ raise AirflowException("PythonVirtualenvOperator requires virtualenv, please install it.")
if not requirements:
- self.requirements: Union[List[str], str] = []
+ self.requirements: list[str] | str = []
elif isinstance(requirements, str):
self.requirements = requirements
else:
self.requirements = list(requirements)
- self.string_args = string_args or []
self.python_version = python_version
- self.use_dill = use_dill
self.system_site_packages = system_site_packages
self.pip_install_options = pip_install_options
- self.pickling_library = dill if self.use_dill else pickle
-
- def execute(self, context: Context) -> Any:
- serializable_keys = set(self._iter_serializable_context_keys())
- serializable_context = context_copy_partial(context, serializable_keys)
- return super().execute(context=serializable_context)
-
- def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
- return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing()
+ super().__init__(
+ python_callable=python_callable,
+ use_dill=use_dill,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ string_args=string_args,
+ templates_dict=templates_dict,
+ templates_exts=templates_exts,
+ expect_airflow=expect_airflow,
+ **kwargs,
+ )
def execute_callable(self):
- with TemporaryDirectory(prefix='venv') as tmp_dir:
- requirements_file_name = f'{tmp_dir}/requirements.txt'
+ with TemporaryDirectory(prefix="venv") as tmp_dir:
+ tmp_path = Path(tmp_dir)
+ requirements_file_name = f"{tmp_dir}/requirements.txt"
if not isinstance(self.requirements, str):
requirements_file_contents = "\n".join(str(dependency) for dependency in self.requirements)
@@ -436,94 +537,184 @@ def execute_callable(self):
requirements_file_contents = self.requirements
if not self.system_site_packages and self.use_dill:
- requirements_file_contents += '\ndill'
+ requirements_file_contents += "\ndill"
- with open(requirements_file_name, 'w') as file:
+ with open(requirements_file_name, "w") as file:
file.write(requirements_file_contents)
-
- if self.templates_dict:
- self.op_kwargs['templates_dict'] = self.templates_dict
-
- input_filename = os.path.join(tmp_dir, 'script.in')
- output_filename = os.path.join(tmp_dir, 'script.out')
- string_args_filename = os.path.join(tmp_dir, 'string_args.txt')
- script_filename = os.path.join(tmp_dir, 'script.py')
-
prepare_virtualenv(
venv_directory=tmp_dir,
- python_bin=f'python{self.python_version}' if self.python_version else None,
+ python_bin=f"python{self.python_version}" if self.python_version else None,
system_site_packages=self.system_site_packages,
requirements_file_path=requirements_file_name,
pip_install_options=self.pip_install_options,
)
+ python_path = tmp_path / "bin" / "python"
- self._write_args(input_filename)
- self._write_string_args(string_args_filename)
- write_python_script(
- jinja_context=dict(
- op_args=self.op_args,
- op_kwargs=self.op_kwargs,
- pickling_library=self.pickling_library.__name__,
- python_callable=self.python_callable.__name__,
- python_callable_source=self.get_python_source(),
- ),
- filename=script_filename,
- render_template_as_native_obj=self.dag.render_template_as_native_obj,
- )
+ return self._execute_python_callable_in_subprocess(python_path, tmp_path)
- execute_in_subprocess(
- cmd=[
- f'{tmp_dir}/bin/python',
- script_filename,
- input_filename,
- output_filename,
- string_args_filename,
- ]
- )
+ def _iter_serializable_context_keys(self):
+ yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
+ if self.system_site_packages or "apache-airflow" in self.requirements:
+ yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
+ yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
+ elif "pendulum" in self.requirements:
+ yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
- return self._read_result(output_filename)
- def get_python_source(self):
- """
- Returns the source of self.python_callable
- @return:
- """
- return dedent(inspect.getsource(self.python_callable))
+class ExternalPythonOperator(_BasePythonVirtualenvOperator):
+ """
+ Allows one to run a function in a virtualenv that is not re-created but used as is
+ without the overhead of creating the virtualenv (with certain caveats).
- def _write_args(self, filename):
- if self.op_args or self.op_kwargs:
- with open(filename, 'wb') as file:
- self.pickling_library.dump({'args': self.op_args, 'kwargs': self.op_kwargs}, file)
+ The function must be defined using def, and not be
+ part of a class. All imports must happen inside the function
+ and no variables outside the scope may be referenced. A global scope
+ variable named virtualenv_string_args will be available (populated by
+ string_args). In addition, one can pass stuff through op_args and op_kwargs, and one
+ can use a return value.
+ Note that if your virtualenv runs in a different Python major version than Airflow,
+ you cannot use return values, op_args, op_kwargs, or use any macros that are being provided to
+ Airflow through plugins. You can use string_args though.
+
+ If Airflow is installed in the external environment in a different version than the version
+ used by the operator, the operator will fail.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:ExternalPythonOperator`
+
+ :param python: Full path string (file-system specific) that points to a Python binary inside
+ a virtualenv that should be used (in the ``VENV/bin`` folder). Should be an absolute path
+ (so it usually starts with "/" or "X:/" depending on the filesystem/OS used).
+ :param python_callable: A python function with no references to outside variables,
+ defined with def, which will be run in a virtualenv
+ :param use_dill: Whether to use dill to serialize
+ the args and result (pickle is default). This allows more complex types
+ but if dill is not preinstalled in your venv, the task will fail with use_dill enabled.
+ :param op_args: A list of positional arguments to pass to python_callable.
+ :param op_kwargs: A dict of keyword arguments to pass to python_callable.
+ :param string_args: Strings that are present in the global var virtualenv_string_args,
+ available to python_callable at runtime as a list[str]. Note that args are split
+ by newline.
+ :param templates_dict: a dictionary where the values are templates that
+ will get templated by the Airflow engine sometime between
+ ``__init__`` and ``execute`` takes place and are made available
+ in your callable's context after the template has been applied
+ :param templates_exts: a list of file extensions to resolve while
+ processing templated fields, for examples ``['.sql', '.hql']``
+ :param expect_airflow: expect Airflow to be installed in the target environment. If true, the operator
+ will log a warning if Airflow is not installed, and it will attempt to load Airflow
+ macros when starting.
+ """
+
+ template_fields: Sequence[str] = tuple({"python"} | set(PythonOperator.template_fields))
+
+ def __init__(
+ self,
+ *,
+ python: str,
+ python_callable: Callable,
+ use_dill: bool = False,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ string_args: Iterable[str] | None = None,
+ templates_dict: dict | None = None,
+ templates_exts: list[str] | None = None,
+ expect_airflow: bool = True,
+ expect_pendulum: bool = False,
+ **kwargs,
+ ):
+ if not python:
+ raise ValueError("Python Path must be defined in ExternalPythonOperator")
+ self.python = python
+ self.expect_pendulum = expect_pendulum
+ super().__init__(
+ python_callable=python_callable,
+ use_dill=use_dill,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ string_args=string_args,
+ templates_dict=templates_dict,
+ templates_exts=templates_exts,
+ expect_airflow=expect_airflow,
+ **kwargs,
+ )
+
+ def execute_callable(self):
+ python_path = Path(self.python)
+ if not python_path.exists():
+ raise ValueError(f"Python Path '{python_path}' must exist")
+ if not python_path.is_file():
+ raise ValueError(f"Python Path '{python_path}' must be a file")
+ if not python_path.is_absolute():
+ raise ValueError(f"Python Path '{python_path}' must be an absolute path.")
+ python_version_as_list_of_strings = self._get_python_version_from_environment()
+ if (
+ python_version_as_list_of_strings
+ and str(python_version_as_list_of_strings[0]) != str(sys.version_info.major)
+ and (self.op_args or self.op_kwargs)
+ ):
+ raise AirflowException(
+ "Passing op_args or op_kwargs is not supported across different Python "
+ "major versions for ExternalPythonOperator. Please use string_args."
+ f" Sys version: {sys.version_info}. Venv version: {python_version_as_list_of_strings}"
+ )
+ with TemporaryDirectory(prefix="tmd") as tmp_dir:
+ tmp_path = Path(tmp_dir)
+ return self._execute_python_callable_in_subprocess(python_path, tmp_path)
+
+ def _get_python_version_from_environment(self) -> list[str]:
+ try:
+ result = subprocess.check_output([self.python, "--version"], text=True)
+ return result.strip().split(" ")[-1].split(".")
+ except Exception as e:
+ raise ValueError(f"Error while executing {self.python}: {e}") from e
def _iter_serializable_context_keys(self):
yield from self.BASE_SERIALIZABLE_CONTEXT_KEYS
- if self.system_site_packages or 'apache-airflow' in self.requirements:
+ if self._get_airflow_version_from_target_env():
yield from self.AIRFLOW_SERIALIZABLE_CONTEXT_KEYS
yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
- elif 'pendulum' in self.requirements:
+ elif self._is_pendulum_installed_in_target_env():
yield from self.PENDULUM_SERIALIZABLE_CONTEXT_KEYS
- def _write_string_args(self, filename):
- with open(filename, 'w') as file:
- file.write('\n'.join(map(str, self.string_args)))
-
- def _read_result(self, filename):
- if os.stat(filename).st_size == 0:
- return None
- with open(filename, 'rb') as file:
- try:
- return self.pickling_library.load(file)
- except ValueError:
- self.log.error(
- "Error deserializing result. Note that result deserialization "
- "is not supported across major Python versions."
+ def _is_pendulum_installed_in_target_env(self) -> bool:
+ try:
+ subprocess.check_call([self.python, "-c", "import pendulum"])
+ return True
+ except Exception as e:
+ if self.expect_pendulum:
+ self.log.warning("When checking for Pendulum installed in venv got %s", e)
+ self.log.warning(
+ "Pendulum is not properly installed in the virtualenv. "
+ "Pendulum context keys will not be available. "
+ "Please install Pendulum or Airflow in your venv to access them."
)
- raise
+ return False
- def __deepcopy__(self, memo):
- # module objects can't be copied _at all__
- memo[id(self.pickling_library)] = self.pickling_library
- return super().__deepcopy__(memo)
+ def _get_airflow_version_from_target_env(self) -> str | None:
+ try:
+ result = subprocess.check_output(
+ [self.python, "-c", "from airflow import version; print(version.version)"], text=True
+ )
+ target_airflow_version = result.strip()
+ if target_airflow_version != airflow_version:
+ raise AirflowConfigException(
+ f"The version of Airflow installed for the {self.python}("
+ f"{target_airflow_version}) is different than the runtime Airflow version: "
+ f"{airflow_version}. Make sure your environment has the same Airflow version "
+ f"installed as the Airflow runtime."
+ )
+ return target_airflow_version
+ except Exception as e:
+ if self.expect_airflow:
+ self.log.warning("When checking for Airflow installed in venv got %s", e)
+ self.log.warning(
+ f"This means that Airflow is not properly installed by "
+ f"{self.python}. Airflow context keys will not be available. "
+ f"Please install Airflow {airflow_version} in your environment to access them."
+ )
+ return None
def get_current_context() -> Context:
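Review note: the major-version check that ``ExternalPythonOperator.execute_callable`` performs can be tried outside Airflow. The following is a minimal standalone sketch of that logic, not the operator's actual code path; the helper names ``get_python_version_parts`` and ``same_major_version`` are hypothetical and exist only for this illustration. It assumes the target interpreter prints something like ``Python 3.10.4`` on stdout for ``--version``.

```python
from __future__ import annotations

import subprocess
import sys


def get_python_version_parts(python: str) -> list[str]:
    """Return the version reported by `<python> --version` as a list of strings."""
    # Hypothetical helper mirroring the parsing in the diff above.
    result = subprocess.check_output([python, "--version"], text=True)
    # Keep the last whitespace-separated token (the version) and split on dots.
    return result.strip().split(" ")[-1].split(".")


def same_major_version(python: str) -> bool:
    """Check whether the target interpreter shares the current major version."""
    # Pickled op_args/op_kwargs and return values are only passed when this holds.
    return get_python_version_parts(python)[0] == str(sys.version_info.major)


if __name__ == "__main__":
    # Compare the running interpreter with itself as a trivial usage example.
    print(get_python_version_parts(sys.executable))
    print(same_major_version(sys.executable))
```

When the major versions differ, the operator in the diff rejects ``op_args`` and ``op_kwargs`` and points the user to ``string_args`` instead.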
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
deleted file mode 100644
index ac8c6448d241d..0000000000000
--- a/airflow/operators/python_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.python`."""
-
-import warnings
-
-from airflow.operators.python import ( # noqa
- BranchPythonOperator,
- PythonOperator,
- PythonVirtualenvOperator,
- ShortCircuitOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.python`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/operators/redshift_to_s3_operator.py b/airflow/operators/redshift_to_s3_operator.py
deleted file mode 100644
index 9fceb700d42c7..0000000000000
--- a/airflow/operators/redshift_to_s3_operator.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.transfers.redshift_to_s3`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.redshift_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class RedshiftToS3Transfer(RedshiftToS3Operator):
- """
- This class is deprecated.
- Please use: :class:`airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.amazon.aws.transfers.redshift_to_s3.RedshiftToS3Operator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/s3_file_transform_operator.py b/airflow/operators/s3_file_transform_operator.py
deleted file mode 100644
index 828031d814102..0000000000000
--- a/airflow/operators/s3_file_transform_operator.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.operators.s3_file_transform`
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3_file_transform import S3FileTransformOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3_file_transform`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/s3_to_hive_operator.py b/airflow/operators/s3_to_hive_operator.py
deleted file mode 100644
index b0e1f6b69258a..0000000000000
--- a/airflow/operators/s3_to_hive_operator.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.apache.hive.transfers.s3_to_hive`."""
-
-import warnings
-
-from airflow.providers.apache.hive.transfers.s3_to_hive import S3ToHiveOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.apache.hive.transfers.s3_to_hive`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class S3ToHiveTransfer(S3ToHiveOperator):
- """
- This class is deprecated.
- Please use `airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.apache.hive.transfers.s3_to_hive.S3ToHiveOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/s3_to_redshift_operator.py b/airflow/operators/s3_to_redshift_operator.py
deleted file mode 100644
index f14a2912a8e8f..0000000000000
--- a/airflow/operators/s3_to_redshift_operator.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This module is deprecated.
-Please use :mod:`airflow.providers.amazon.aws.transfers.s3_to_redshift`.
-"""
-
-import warnings
-
-from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.transfers.s3_to_redshift`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class S3ToRedshiftTransfer(S3ToRedshiftOperator):
- """This class is deprecated.
-
- Please use:
- `airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use
- `airflow.providers.amazon.aws.transfers.s3_to_redshift.S3ToRedshiftOperator`.""",
- DeprecationWarning,
- stacklevel=3,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/slack_operator.py b/airflow/operators/slack_operator.py
deleted file mode 100644
index 3af49e222218e..0000000000000
--- a/airflow/operators/slack_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.slack.operators.slack`."""
-
-import warnings
-
-from airflow.providers.slack.operators.slack import SlackAPIOperator, SlackAPIPostOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.slack.operators.slack`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/smooth.py b/airflow/operators/smooth.py
index 9dcbccb9e0a2f..66bf632de8395 100644
--- a/airflow/operators/smooth.py
+++ b/airflow/operators/smooth.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context
@@ -25,7 +27,7 @@ class SmoothOperator(BaseOperator):
Sade song "Smooth Operator".
"""
- ui_color = '#e8f7e4'
+ ui_color = "#e8f7e4"
yt_link: str = "https://www.youtube.com/watch?v=4TYv2PhG89A"
def __init__(self, **kwargs) -> None:
diff --git a/airflow/operators/sql.py b/airflow/operators/sql.py
deleted file mode 100644
index efa5d0d81a8f0..0000000000000
--- a/airflow/operators/sql.py
+++ /dev/null
@@ -1,557 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, SupportsAbs, Union
-
-from airflow.compat.functools import cached_property
-from airflow.exceptions import AirflowException
-from airflow.hooks.base import BaseHook
-from airflow.hooks.dbapi import DbApiHook
-from airflow.models import BaseOperator, SkipMixin
-from airflow.utils.context import Context
-
-
-def parse_boolean(val: str) -> Union[str, bool]:
- """Try to parse a string into boolean.
-
- Raises ValueError if the input is not a valid true- or false-like string value.
- """
- val = val.lower()
- if val in ('y', 'yes', 't', 'true', 'on', '1'):
- return True
- if val in ('n', 'no', 'f', 'false', 'off', '0'):
- return False
- raise ValueError(f"{val!r} is not a boolean-like string value")
-
-
-class BaseSQLOperator(BaseOperator):
- """
- This is a base class for generic SQL Operator to get a DB Hook
-
- The provided method is .get_db_hook(). The default behavior will try to
- retrieve the DB hook based on connection type.
- You can custom the behavior by overriding the .get_db_hook() method.
- """
-
- def __init__(
- self,
- *,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- hook_params: Optional[Dict] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.conn_id = conn_id
- self.database = database
- self.hook_params = {} if hook_params is None else hook_params
-
- @cached_property
- def _hook(self):
- """Get DB Hook based on connection type"""
- self.log.debug("Get connection for %s", self.conn_id)
- conn = BaseHook.get_connection(self.conn_id)
-
- hook = conn.get_hook(hook_params=self.hook_params)
- if not isinstance(hook, DbApiHook):
- raise AirflowException(
- f'The connection type is not supported by {self.__class__.__name__}. '
- f'The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}'
- )
-
- if self.database:
- hook.schema = self.database
-
- return hook
-
- def get_db_hook(self) -> DbApiHook:
- """
- Get the database hook for the connection.
-
- :return: the database hook object.
- :rtype: DbApiHook
- """
- return self._hook
-
-
-class SQLCheckOperator(BaseSQLOperator):
- """
- Performs checks against a db. The ``SQLCheckOperator`` expects
- a sql query that will return a single row. Each value on that
- first row is evaluated using python ``bool`` casting. If any of the
- values return ``False`` the check is failed and errors out.
-
- Note that Python bool casting evals the following as ``False``:
-
- * ``False``
- * ``0``
- * Empty string (``""``)
- * Empty list (``[]``)
- * Empty dictionary or set (``{}``)
-
- Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
- the count ``== 0``. You can craft much more complex query that could,
- for instance, check that the table has the same number of rows as
- the source table upstream, or that the count of today's partition is
- greater than yesterday's partition, or that a set of metrics are less
- than 3 standard deviation for the 7 day average.
-
- This operator can be used as a data quality check in your pipeline, and
- depending on where you put it in your DAG, you have the choice to
- stop the critical path, preventing from
- publishing dubious data, or on the side and receive email alerts
- without stopping the progress of the DAG.
-
- :param sql: the sql to be executed. (templated)
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- """
-
- template_fields: Sequence[str] = ("sql",)
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql": "sql"}
- ui_color = "#fff7e6"
-
- def __init__(
- self, *, sql: str, conn_id: Optional[str] = None, database: Optional[str] = None, **kwargs
- ) -> None:
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
-
- def execute(self, context: Context):
- self.log.info("Executing SQL check: %s", self.sql)
- records = self.get_db_hook().get_first(self.sql)
-
- self.log.info("Record: %s", records)
- if not records:
- raise AirflowException("The query returned None")
- elif not all(bool(r) for r in records):
- raise AirflowException(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
-
- self.log.info("Success.")
-
-
-def _convert_to_float_if_possible(s):
- """
- A small helper function to convert a string to a numeric value
- if appropriate
-
- :param s: the string to be converted
- """
- try:
- ret = float(s)
- except (ValueError, TypeError):
- ret = s
- return ret
-
-
-class SQLValueCheckOperator(BaseSQLOperator):
- """
- Performs a simple value check using sql code.
-
- :param sql: the sql to be executed. (templated)
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- """
-
- __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
- template_fields: Sequence[str] = (
- "sql",
- "pass_value",
- )
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql": "sql"}
- ui_color = "#fff7e6"
-
- def __init__(
- self,
- *,
- sql: str,
- pass_value: Any,
- tolerance: Any = None,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
- self.pass_value = str(pass_value)
- tol = _convert_to_float_if_possible(tolerance)
- self.tol = tol if isinstance(tol, float) else None
- self.has_tolerance = self.tol is not None
-
- def execute(self, context=None):
- self.log.info("Executing SQL check: %s", self.sql)
- records = self.get_db_hook().get_first(self.sql)
-
- if not records:
- raise AirflowException("The query returned None")
-
- pass_value_conv = _convert_to_float_if_possible(self.pass_value)
- is_numeric_value_check = isinstance(pass_value_conv, float)
-
- tolerance_pct_str = str(self.tol * 100) + "%" if self.has_tolerance else None
- error_msg = (
- "Test failed.\nPass value:{pass_value_conv}\n"
- "Tolerance:{tolerance_pct_str}\n"
- "Query:\n{sql}\nResults:\n{records!s}"
- ).format(
- pass_value_conv=pass_value_conv,
- tolerance_pct_str=tolerance_pct_str,
- sql=self.sql,
- records=records,
- )
-
- if not is_numeric_value_check:
- tests = self._get_string_matches(records, pass_value_conv)
- elif is_numeric_value_check:
- try:
- numeric_records = self._to_float(records)
- except (ValueError, TypeError):
- raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
- tests = self._get_numeric_matches(numeric_records, pass_value_conv)
- else:
- tests = []
-
- if not all(tests):
- raise AirflowException(error_msg)
-
- def _to_float(self, records):
- return [float(record) for record in records]
-
- def _get_string_matches(self, records, pass_value_conv):
- return [str(record) == pass_value_conv for record in records]
-
- def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
- if self.has_tolerance:
- return [
- numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
- for record in numeric_records
- ]
-
- return [record == numeric_pass_value_conv for record in numeric_records]
-
-
-class SQLIntervalCheckOperator(BaseSQLOperator):
- """
- Checks that the values of metrics given as SQL expressions are within
- a certain tolerance of the ones from days_back before.
-
- :param table: the table name
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which will overwrite the defined one in connection
- :param days_back: number of days between ds and the ds we want to check
- against. Defaults to 7 days
- :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
- :param ratio_formula: which formula to use to compute the ratio between
- the two metrics. Assuming cur is the metric of today and ref is
- the metric to today - days_back.
-
- max_over_min: computes max(cur, ref) / min(cur, ref)
- relative_diff: computes abs(cur-ref) / ref
-
- Default: 'max_over_min'
- :param ignore_zero: whether we should ignore zero metrics
- :param metrics_thresholds: a dictionary of ratios indexed by metrics
- """
-
- __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
- template_fields: Sequence[str] = ("sql1", "sql2")
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql1": "sql", "sql2": "sql"}
- ui_color = "#fff7e6"
-
- ratio_formulas = {
- "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
- "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
- }
-
- def __init__(
- self,
- *,
- table: str,
- metrics_thresholds: Dict[str, int],
- date_filter_column: Optional[str] = "ds",
- days_back: SupportsAbs[int] = -7,
- ratio_formula: Optional[str] = "max_over_min",
- ignore_zero: bool = True,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- if ratio_formula not in self.ratio_formulas:
- msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
-
- raise AirflowException(
- msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
- )
- self.ratio_formula = ratio_formula
- self.ignore_zero = ignore_zero
- self.table = table
- self.metrics_thresholds = metrics_thresholds
- self.metrics_sorted = sorted(metrics_thresholds.keys())
- self.date_filter_column = date_filter_column
- self.days_back = -abs(days_back)
- sqlexp = ", ".join(self.metrics_sorted)
- sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="
-
- self.sql1 = sqlt + "'{{ ds }}'"
- self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
-
- def execute(self, context=None):
- hook = self.get_db_hook()
- self.log.info("Using ratio formula: %s", self.ratio_formula)
- self.log.info("Executing SQL check: %s", self.sql2)
- row2 = hook.get_first(self.sql2)
- self.log.info("Executing SQL check: %s", self.sql1)
- row1 = hook.get_first(self.sql1)
-
- if not row2:
- raise AirflowException(f"The query {self.sql2} returned None")
- if not row1:
- raise AirflowException(f"The query {self.sql1} returned None")
-
- current = dict(zip(self.metrics_sorted, row1))
- reference = dict(zip(self.metrics_sorted, row2))
-
- ratios = {}
- test_results = {}
-
- for metric in self.metrics_sorted:
- cur = current[metric]
- ref = reference[metric]
- threshold = self.metrics_thresholds[metric]
- if cur == 0 or ref == 0:
- ratios[metric] = None
- test_results[metric] = self.ignore_zero
- else:
- ratios[metric] = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
- test_results[metric] = ratios[metric] < threshold
-
- self.log.info(
- (
- "Current metric for %s: %s\n"
- "Past metric for %s: %s\n"
- "Ratio for %s: %s\n"
- "Threshold: %s\n"
- ),
- metric,
- cur,
- metric,
- ref,
- metric,
- ratios[metric],
- threshold,
- )
-
- if not all(test_results.values()):
- failed_tests = [it[0] for it in test_results.items() if not it[1]]
- self.log.warning(
- "The following %s tests out of %s failed:",
- len(failed_tests),
- len(self.metrics_sorted),
- )
- for k in failed_tests:
- self.log.warning(
- "'%s' check failed. %s is above %s",
- k,
- ratios[k],
- self.metrics_thresholds[k],
- )
- raise AirflowException(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
-
- self.log.info("All tests have passed")
-
-
-class SQLThresholdCheckOperator(BaseSQLOperator):
- """
- Performs a value check using sql code against a minimum threshold
- and a maximum threshold. Thresholds can be in the form of a numeric
- value OR a sql statement that results a numeric.
-
- :param sql: the sql to be executed. (templated)
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- :param min_threshold: numerical value or min threshold sql to be executed (templated)
- :param max_threshold: numerical value or max threshold sql to be executed (templated)
- """
-
- template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
- template_ext: Sequence[str] = (
- ".hql",
- ".sql",
- )
- template_fields_renderers = {"sql": "sql"}
-
- def __init__(
- self,
- *,
- sql: str,
- min_threshold: Any,
- max_threshold: Any,
- conn_id: Optional[str] = None,
- database: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
- self.min_threshold = _convert_to_float_if_possible(min_threshold)
- self.max_threshold = _convert_to_float_if_possible(max_threshold)
-
- def execute(self, context=None):
- hook = self.get_db_hook()
- result = hook.get_first(self.sql)[0]
-
- if isinstance(self.min_threshold, float):
- lower_bound = self.min_threshold
- else:
- lower_bound = hook.get_first(self.min_threshold)[0]
-
- if isinstance(self.max_threshold, float):
- upper_bound = self.max_threshold
- else:
- upper_bound = hook.get_first(self.max_threshold)[0]
-
- meta_data = {
- "result": result,
- "task_id": self.task_id,
- "min_threshold": lower_bound,
- "max_threshold": upper_bound,
- "within_threshold": lower_bound <= result <= upper_bound,
- }
-
- self.push(meta_data)
- if not meta_data["within_threshold"]:
- error_msg = (
- f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
- f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
- f'Check description: {meta_data.get("description")}\n'
- f"SQL: {self.sql}\n"
- f'Result: {round(meta_data.get("result"), 2)} is not within thresholds '
- f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
- )
- raise AirflowException(error_msg)
-
- self.log.info("Test %s Successful.", self.task_id)
-
- def push(self, meta_data):
- """
- Optional: Send data check info and metadata to an external database.
- Default functionality will log metadata.
- """
- info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items())
- self.log.info("Log from %s:\n%s", self.dag_id, info)
-
-
-class BranchSQLOperator(BaseSQLOperator, SkipMixin):
- """
- Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
-
- :param sql: The SQL code to be executed, should return true or false (templated)
- Template reference are recognized by str ending in '.sql'.
- Expected SQL query to return Boolean (True/False), integer (0 = False, Otherwise = 1)
- or string (true/y/yes/1/on/false/n/no/0/off).
- :param follow_task_ids_if_true: task id or task ids to follow if query returns true
- :param follow_task_ids_if_false: task id or task ids to follow if query returns false
- :param conn_id: the connection ID used to connect to the database.
- :param database: name of database which overwrite the defined one in connection
- :param parameters: (optional) the parameters to render the SQL query with.
- """
-
- template_fields: Sequence[str] = ("sql",)
- template_ext: Sequence[str] = (".sql",)
- template_fields_renderers = {"sql": "sql"}
- ui_color = "#a22034"
- ui_fgcolor = "#F7F7F7"
-
- def __init__(
- self,
- *,
- sql: str,
- follow_task_ids_if_true: List[str],
- follow_task_ids_if_false: List[str],
- conn_id: str = "default_conn_id",
- database: Optional[str] = None,
- parameters: Optional[Union[Mapping, Iterable]] = None,
- **kwargs,
- ) -> None:
- super().__init__(conn_id=conn_id, database=database, **kwargs)
- self.sql = sql
- self.parameters = parameters
- self.follow_task_ids_if_true = follow_task_ids_if_true
- self.follow_task_ids_if_false = follow_task_ids_if_false
-
- def execute(self, context: Context):
- self.log.info(
- "Executing: %s (with parameters %s) with connection: %s",
- self.sql,
- self.parameters,
- self.conn_id,
- )
- record = self.get_db_hook().get_first(self.sql, self.parameters)
- if not record:
- raise AirflowException(
- "No rows returned from sql query. Operator expected True or False return value."
- )
-
- if isinstance(record, list):
- if isinstance(record[0], list):
- query_result = record[0][0]
- else:
- query_result = record[0]
- elif isinstance(record, tuple):
- query_result = record[0]
- else:
- query_result = record
-
- self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
-
- follow_branch = None
- try:
- if isinstance(query_result, bool):
- if query_result:
- follow_branch = self.follow_task_ids_if_true
- elif isinstance(query_result, str):
- # return result is not Boolean, try to convert from String to Boolean
- if parse_boolean(query_result):
- follow_branch = self.follow_task_ids_if_true
- elif isinstance(query_result, int):
- if bool(query_result):
- follow_branch = self.follow_task_ids_if_true
- else:
- raise AirflowException(
- f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
- )
-
- if follow_branch is None:
- follow_branch = self.follow_task_ids_if_false
- except ValueError:
- raise AirflowException(
- f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
- )
-
- self.skip_all_except(context["ti"], follow_branch)
diff --git a/airflow/operators/sql_branch_operator.py b/airflow/operators/sql_branch_operator.py
deleted file mode 100644
index 5987bce61003f..0000000000000
--- a/airflow/operators/sql_branch_operator.py
+++ /dev/null
@@ -1,40 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.sql`."""
-import warnings
-
-from airflow.operators.sql import BranchSQLOperator
-
-warnings.warn(
- "This module is deprecated. Please use :mod:`airflow.operators.sql`.", DeprecationWarning, stacklevel=2
-)
-
-
-class BranchSqlOperator(BranchSQLOperator):
- """
- This class is deprecated.
- Please use `airflow.operators.sql.BranchSQLOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.operators.sql.BranchSQLOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
diff --git a/airflow/operators/sqlite_operator.py b/airflow/operators/sqlite_operator.py
deleted file mode 100644
index 68791d69846c0..0000000000000
--- a/airflow/operators/sqlite_operator.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.sqlite.operators.sqlite`."""
-
-import warnings
-
-from airflow.providers.sqlite.operators.sqlite import SqliteOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.sqlite.operators.sqlite`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index bd81314dda8fd..1b3f0049788cc 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -19,16 +19,16 @@
This module is deprecated. Please use :mod:`airflow.utils.task_group`.
The module which provides a way to nest your DAGs and so your levels of complexity.
"""
+from __future__ import annotations
import warnings
from datetime import datetime
from enum import Enum
-from typing import Dict, Optional, Tuple
from sqlalchemy.orm.session import Session
from airflow.api.common.experimental.get_task_instance import get_task_instance
-from airflow.exceptions import AirflowException, TaskInstanceNotFound
+from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, TaskInstanceNotFound
from airflow.models import DagRun
from airflow.models.dag import DAG, DagContext
from airflow.models.pool import Pool
@@ -43,8 +43,8 @@
class SkippedStatePropagationOptions(Enum):
"""Available options for skipped state propagation of subdag's tasks to parent dag tasks."""
- ALL_LEAVES = 'all_leaves'
- ANY_LEAF = 'any_leaf'
+ ALL_LEAVES = "all_leaves"
+ ANY_LEAF = "any_leaf"
class SubDagOperator(BaseSensorOperator):
@@ -66,10 +66,10 @@ class SubDagOperator(BaseSensorOperator):
parent dag's downstream task.
"""
- ui_color = '#555'
- ui_fgcolor = '#fff'
+ ui_color = "#555"
+ ui_fgcolor = "#fff"
- subdag: "DAG"
+ subdag: DAG
@provide_session
def __init__(
@@ -77,8 +77,8 @@ def __init__(
*,
subdag: DAG,
session: Session = NEW_SESSION,
- conf: Optional[Dict] = None,
- propagate_skipped_state: Optional[SkippedStatePropagationOptions] = None,
+ conf: dict | None = None,
+ propagate_skipped_state: SkippedStatePropagationOptions | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -91,17 +91,17 @@ def __init__(
warnings.warn(
"""This class is deprecated. Please use `airflow.utils.task_group.TaskGroup`.""",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=4,
)
def _validate_dag(self, kwargs):
- dag = kwargs.get('dag') or DagContext.get_current_dag()
+ dag = kwargs.get("dag") or DagContext.get_current_dag()
if not dag:
- raise AirflowException('Please pass in the `dag` param or call within a DAG context manager')
+ raise AirflowException("Please pass in the `dag` param or call within a DAG context manager")
- if dag.dag_id + '.' + kwargs['task_id'] != self.subdag.dag_id:
+ if dag.dag_id + "." + kwargs["task_id"] != self.subdag.dag_id:
raise AirflowException(
f"The subdag's dag_id should have the form '{{parent_dag_id}}.{{this_task_id}}'. "
f"Expected '{dag.dag_id}.{kwargs['task_id']}'; received '{self.subdag.dag_id}'."
@@ -152,14 +152,14 @@ def _reset_dag_run_and_task_instances(self, dag_run, execution_date):
def pre_execute(self, context):
super().pre_execute(context)
- execution_date = context['execution_date']
+ execution_date = context["execution_date"]
dag_run = self._get_dagrun(execution_date)
if dag_run is None:
- if context['data_interval_start'] is None or context['data_interval_end'] is None:
- data_interval: Optional[Tuple[datetime, datetime]] = None
+ if context["data_interval_start"] is None or context["data_interval_end"] is None:
+ data_interval: tuple[datetime, datetime] | None = None
else:
- data_interval = (context['data_interval_start'], context['data_interval_end'])
+ data_interval = (context["data_interval_start"], context["data_interval_end"])
dag_run = self.subdag.create_dagrun(
run_type=DagRunType.SCHEDULED,
execution_date=execution_date,
@@ -175,13 +175,13 @@ def pre_execute(self, context):
self._reset_dag_run_and_task_instances(dag_run, execution_date)
def poke(self, context: Context):
- execution_date = context['execution_date']
+ execution_date = context["execution_date"]
dag_run = self._get_dagrun(execution_date=execution_date)
return dag_run.state != State.RUNNING
def post_execute(self, context, result=None):
super().post_execute(context)
- execution_date = context['execution_date']
+ execution_date = context["execution_date"]
dag_run = self._get_dagrun(execution_date=execution_date)
self.log.info("Execution finished. State is %s", dag_run.state)
@@ -192,14 +192,14 @@ def post_execute(self, context, result=None):
self._skip_downstream_tasks(context)
def _check_skipped_states(self, context):
- leaves_tis = self._get_leaves_tis(context['execution_date'])
+ leaves_tis = self._get_leaves_tis(context["execution_date"])
if self.propagate_skipped_state == SkippedStatePropagationOptions.ANY_LEAF:
return any(ti.state == State.SKIPPED for ti in leaves_tis)
if self.propagate_skipped_state == SkippedStatePropagationOptions.ALL_LEAVES:
return all(ti.state == State.SKIPPED for ti in leaves_tis)
raise AirflowException(
- f'Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used.'
+ f"Unimplemented SkippedStatePropagationOptions {self.propagate_skipped_state} used."
)
def _get_leaves_tis(self, execution_date):
@@ -216,15 +216,15 @@ def _get_leaves_tis(self, execution_date):
def _skip_downstream_tasks(self, context):
self.log.info(
- 'Skipping downstream tasks because propagate_skipped_state is set to %s '
- 'and skipped task(s) were found.',
+ "Skipping downstream tasks because propagate_skipped_state is set to %s "
+ "and skipped task(s) were found.",
self.propagate_skipped_state,
)
- downstream_tasks = context['task'].downstream_list
- self.log.debug('Downstream task_ids %s', downstream_tasks)
+ downstream_tasks = context["task"].downstream_list
+ self.log.debug("Downstream task_ids %s", downstream_tasks)
if downstream_tasks:
- self.skip(context['dag_run'], context['execution_date'], downstream_tasks)
+ self.skip(context["dag_run"], context["execution_date"], downstream_tasks)
- self.log.info('Done.')
+ self.log.info("Done.")
diff --git a/airflow/operators/subdag_operator.py b/airflow/operators/subdag_operator.py
deleted file mode 100644
index bb5a088d23b6d..0000000000000
--- a/airflow/operators/subdag_operator.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.operators.subdag`."""
-
-import warnings
-
-from airflow.operators.subdag import SkippedStatePropagationOptions, SubDagOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.operators.subdag`.", DeprecationWarning, stacklevel=2
-)
diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py
index 0689f14c56261..b687c05dc4b97 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -15,15 +15,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import datetime
import json
import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Sequence, cast
from airflow.api.common.trigger_dag import trigger_dag
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
-from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun
+from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
+from airflow.models.dag import DagModel
+from airflow.models.dagbag import DagBag
+from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCom
from airflow.utils import timezone
from airflow.utils.context import Context
@@ -36,7 +40,6 @@
if TYPE_CHECKING:
- from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskinstance import TaskInstanceKey
@@ -46,14 +49,9 @@ class TriggerDagRunLink(BaseOperatorLink):
DAG triggered by task using TriggerDagRunOperator.
"""
- name = 'Triggered DAG'
+ name = "Triggered DAG"
- def get_link(
- self,
- operator: "AbstractOperator",
- *,
- ti_key: "TaskInstanceKey",
- ) -> str:
+ def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey) -> str:
# Fetch the correct execution date for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
when = XCom.get_value(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
@@ -72,6 +70,8 @@ class TriggerDagRunOperator(BaseOperator):
:param execution_date: Execution date for the dag (templated).
:param reset_dag_run: Whether or not clear existing dag run if already exists.
This is useful when backfill or rerun an existing dag run.
+ This only resets (not recreates) the dag run.
+ Dag run conf is immutable and will not be reset on rerun of an existing dag run.
When reset_dag_run=False and dag run exists, DagRunAlreadyExists will be raised.
When reset_dag_run=True and dag run exists, existing dag run will be cleared to rerun.
:param wait_for_completion: Whether or not wait for dag run completion. (default: False)
@@ -79,29 +79,26 @@ class TriggerDagRunOperator(BaseOperator):
(default: 60)
:param allowed_states: List of allowed states, default is ``['success']``.
:param failed_states: List of failed or dis-allowed states, default is ``None``.
+ :param notes: Set a custom note for the newly created DagRun.
"""
template_fields: Sequence[str] = ("trigger_dag_id", "trigger_run_id", "execution_date", "conf")
template_fields_renderers = {"conf": "py"}
ui_color = "#ffefeb"
-
- @property
- def operator_extra_links(self):
- """Return operator extra links"""
- return [TriggerDagRunLink()]
+ operator_extra_links = [TriggerDagRunLink()]
def __init__(
self,
*,
trigger_dag_id: str,
- trigger_run_id: Optional[str] = None,
- conf: Optional[Dict] = None,
- execution_date: Optional[Union[str, datetime.datetime]] = None,
+ trigger_run_id: str | None = None,
+ conf: dict | None = None,
+ execution_date: str | datetime.datetime | None = None,
reset_dag_run: bool = False,
wait_for_completion: bool = False,
poke_interval: int = 60,
- allowed_states: Optional[List] = None,
- failed_states: Optional[List] = None,
+ allowed_states: list | None = None,
+ failed_states: list | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -121,11 +118,6 @@ def __init__(
self.execution_date = execution_date
- try:
- json.dumps(self.conf)
- except TypeError:
- raise AirflowException("conf parameter should be JSON Serializable")
-
def execute(self, context: Context):
if isinstance(self.execution_date, datetime.datetime):
parsed_execution_date = self.execution_date
@@ -134,6 +126,11 @@ def execute(self, context: Context):
else:
parsed_execution_date = timezone.utcnow()
+ try:
+ json.dumps(self.conf)
+ except TypeError:
+ raise AirflowException("conf parameter should be JSON Serializable")
+
if self.trigger_run_id:
run_id = self.trigger_run_id
else:
@@ -160,14 +157,14 @@ def execute(self, context: Context):
dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dag_bag.get_dag(self.trigger_dag_id)
dag.clear(start_date=parsed_execution_date, end_date=parsed_execution_date)
- dag_run = DagRun.find(dag_id=dag.dag_id, run_id=run_id)[0]
+ dag_run = e.dag_run
else:
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
# Store the execution date from the dag run (either created or found above) to
# be used when creating the extra link on the webserver.
- ti = context['task_instance']
+ ti = context["task_instance"]
ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat())
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)
@@ -175,7 +172,7 @@ def execute(self, context: Context):
# wait for dag to complete
while True:
self.log.info(
- 'Waiting for %s on %s to become allowed state %s ...',
+ "Waiting for %s on %s to become allowed state %s ...",
self.trigger_dag_id,
dag_run.execution_date,
self.allowed_states,
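The hunks above move the JSON-serializability check on ``conf`` out of ``__init__`` and into ``execute``. The diff does not state the motivation; one plausible reading is that ``conf`` is listed in ``template_fields``, so its final value is only known at run time. A minimal, Airflow-free sketch of that pattern follows. ``DeferredConfTask`` and ``ensure_json_serializable`` are hypothetical names, and ``ValueError`` stands in for ``AirflowException``.

```python
import json


def ensure_json_serializable(conf):
    """Return ``conf`` unchanged, or raise ValueError if it cannot be dumped to JSON."""
    try:
        json.dumps(conf)
    except TypeError as exc:
        raise ValueError("conf parameter should be JSON Serializable") from exc
    return conf


class DeferredConfTask:
    """Hypothetical stand-in for an operator with a templated ``conf`` field."""

    def __init__(self, conf=None):
        # No validation here: the value may still be a placeholder
        # that is only resolved shortly before execution.
        self.conf = conf

    def execute(self):
        # Validate only once the final value is available.
        return ensure_json_serializable(self.conf)
```

With this layout, constructing the task never fails on an unrendered value; a non-serializable ``conf`` surfaces as an error only when the task actually runs.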
diff --git a/airflow/operators/weekday.py b/airflow/operators/weekday.py
index b23d57e9fb1d4..2b4e0f4698df6 100644
--- a/airflow/operators/weekday.py
+++ b/airflow/operators/weekday.py
@@ -15,9 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import warnings
-from typing import Iterable, Union
+from typing import Iterable
+from airflow.exceptions import RemovedInAirflow3Warning
from airflow.operators.branch import BaseBranchOperator
from airflow.utils import timezone
from airflow.utils.context import Context
@@ -30,6 +33,40 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:BranchDayOfWeekOperator`
+ **Example** (with single day): ::
+
+ from airflow.operators.empty import EmptyOperator
+
+ monday = EmptyOperator(task_id='monday')
+ other_day = EmptyOperator(task_id='other_day')
+
+ monday_check = BranchDayOfWeekOperator(
+ task_id='monday_check',
+ week_day='Monday',
+ use_task_logical_date=True,
+ follow_task_ids_if_true='monday',
+ follow_task_ids_if_false='other_day',
+ dag=dag)
+ monday_check >> [monday, other_day]
+
+ **Example** (with :class:`~airflow.utils.weekday.WeekDay` enum): ::
+
+ # import WeekDay Enum
+ from airflow.utils.weekday import WeekDay
+ from airflow.operators.empty import EmptyOperator
+
+ workday = EmptyOperator(task_id='workday')
+ weekend = EmptyOperator(task_id='weekend')
+ weekend_check = BranchDayOfWeekOperator(
+ task_id='weekend_check',
+ week_day={WeekDay.SATURDAY, WeekDay.SUNDAY},
+ use_task_logical_date=True,
+ follow_task_ids_if_true='weekend',
+ follow_task_ids_if_false='workday',
+ dag=dag)
+ # add downstream dependencies as you would do with any branch operator
+ weekend_check >> [workday, weekend]
+
:param follow_task_ids_if_true: task id or task ids to follow if criteria met
:param follow_task_ids_if_false: task id or task ids to follow if criteria is not met
:param week_day: Day of the week to check (full name). Optionally, a set
@@ -41,17 +78,20 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
* ``{WeekDay.TUESDAY}``
* ``{WeekDay.SATURDAY, WeekDay.SUNDAY}``
+ To use `WeekDay` enum, import it from `airflow.utils.weekday`
+
:param use_task_logical_date: If ``True``, uses task's logical date to compare
with is_today. Execution Date is Useful for backfilling.
If ``False``, uses system's day of the week.
+ :param use_task_execution_day: deprecated parameter, same effect as `use_task_logical_date`
"""
def __init__(
self,
*,
- follow_task_ids_if_true: Union[str, Iterable[str]],
- follow_task_ids_if_false: Union[str, Iterable[str]],
- week_day: Union[str, Iterable[str]],
+ follow_task_ids_if_true: str | Iterable[str],
+ follow_task_ids_if_false: str | Iterable[str],
+ week_day: str | Iterable[str] | WeekDay | Iterable[WeekDay],
use_task_logical_date: bool = False,
use_task_execution_day: bool = False,
**kwargs,
@@ -65,12 +105,12 @@ def __init__(
self.use_task_logical_date = use_task_execution_day
warnings.warn(
"Parameter ``use_task_execution_day`` is deprecated. Use ``use_task_logical_date``.",
- DeprecationWarning,
+ RemovedInAirflow3Warning,
stacklevel=2,
)
self._week_day_num = WeekDay.validate_week_day(week_day)
- def choose_branch(self, context: Context) -> Union[str, Iterable[str]]:
+ def choose_branch(self, context: Context) -> str | Iterable[str]:
if self.use_task_logical_date:
now = context["logical_date"]
else:
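The ``weekday.py`` hunks above keep the deprecated ``use_task_execution_day`` argument working but switch the warning category from ``DeprecationWarning`` to ``RemovedInAirflow3Warning``. The sketch below illustrates the same pattern without importing Airflow. ``RemovedInNextMajorWarning`` and ``resolve_use_logical_date`` are hypothetical names, and the sketch assumes the project-specific category subclasses ``DeprecationWarning`` so that existing warning filters keep matching.

```python
import warnings


class RemovedInNextMajorWarning(DeprecationWarning):
    """Hypothetical stand-in for a project-specific deprecation category."""


def resolve_use_logical_date(use_task_logical_date=False, use_task_execution_day=False):
    """Return the effective flag, warning when the deprecated alias is used."""
    if use_task_execution_day:
        warnings.warn(
            "Parameter ``use_task_execution_day`` is deprecated. Use ``use_task_logical_date``.",
            RemovedInNextMajorWarning,
            stacklevel=2,  # attribute the warning to the caller, not to this helper
        )
        return use_task_execution_day
    return use_task_logical_date
```

A dedicated category lets callers filter or escalate only the warnings tied to a specific removal milestone while leaving other deprecation warnings untouched.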
diff --git a/airflow/plugins_manager.py b/airflow/plugins_manager.py
index 82e295fa1906c..51c74a37c3b1c 100644
--- a/airflow/plugins_manager.py
+++ b/airflow/plugins_manager.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""Manages all plugins."""
+from __future__ import annotations
+
import importlib
import importlib.machinery
import importlib.util
@@ -24,7 +26,7 @@
import os
import sys
import types
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Iterable
try:
import importlib_metadata
@@ -45,26 +47,26 @@
log = logging.getLogger(__name__)
-import_errors: Dict[str, str] = {}
+import_errors: dict[str, str] = {}
-plugins = None # type: Optional[List[AirflowPlugin]]
+plugins: list[AirflowPlugin] | None = None
# Plugin components to integrate as modules
-registered_hooks: Optional[List['BaseHook']] = None
-macros_modules: Optional[List[Any]] = None
-executors_modules: Optional[List[Any]] = None
+registered_hooks: list[BaseHook] | None = None
+macros_modules: list[Any] | None = None
+executors_modules: list[Any] | None = None
# Plugin components to integrate directly
-admin_views: Optional[List[Any]] = None
-flask_blueprints: Optional[List[Any]] = None
-menu_links: Optional[List[Any]] = None
-flask_appbuilder_views: Optional[List[Any]] = None
-flask_appbuilder_menu_links: Optional[List[Any]] = None
-global_operator_extra_links: Optional[List[Any]] = None
-operator_extra_links: Optional[List[Any]] = None
-registered_operator_link_classes: Optional[Dict[str, Type]] = None
-registered_ti_dep_classes: Optional[Dict[str, Type]] = None
-timetable_classes: Optional[Dict[str, Type["Timetable"]]] = None
+admin_views: list[Any] | None = None
+flask_blueprints: list[Any] | None = None
+menu_links: list[Any] | None = None
+flask_appbuilder_views: list[Any] | None = None
+flask_appbuilder_menu_links: list[Any] | None = None
+global_operator_extra_links: list[Any] | None = None
+operator_extra_links: list[Any] | None = None
+registered_operator_link_classes: dict[str, type] | None = None
+registered_ti_dep_classes: dict[str, type] | None = None
+timetable_classes: dict[str, type[Timetable]] | None = None
"""Mapping of class names to class of OperatorLinks registered by plugins.
Used by the DAG serialization code to only allow specific classes to be created
@@ -113,7 +115,7 @@ class EntryPointSource(AirflowPluginSource):
"""Class used to define Plugins loaded from entrypoint."""
def __init__(self, entrypoint: importlib_metadata.EntryPoint, dist: importlib_metadata.Distribution):
- self.dist = dist.metadata['name']
+ self.dist = dist.metadata["Name"]
self.version = dist.version
self.entrypoint = str(entrypoint)
@@ -131,16 +133,16 @@ class AirflowPluginException(Exception):
class AirflowPlugin:
"""Class used to define AirflowPlugin."""
- name: Optional[str] = None
- source: Optional[AirflowPluginSource] = None
- hooks: List[Any] = []
- executors: List[Any] = []
- macros: List[Any] = []
- admin_views: List[Any] = []
- flask_blueprints: List[Any] = []
- menu_links: List[Any] = []
- appbuilder_views: List[Any] = []
- appbuilder_menu_items: List[Any] = []
+ name: str | None = None
+ source: AirflowPluginSource | None = None
+ hooks: list[Any] = []
+ executors: list[Any] = []
+ macros: list[Any] = []
+ admin_views: list[Any] = []
+ flask_blueprints: list[Any] = []
+ menu_links: list[Any] = []
+ appbuilder_views: list[Any] = []
+ appbuilder_menu_items: list[Any] = []
# A list of global operator extra links that can redirect users to
# external systems. These extra links will be available on the
@@ -148,20 +150,20 @@ class AirflowPlugin:
#
# Note: the global operator extra link can be overridden at each
# operator level.
- global_operator_extra_links: List[Any] = []
+ global_operator_extra_links: list[Any] = []
# A list of operator extra links to override or add operator links
# to existing Airflow Operators.
# These extra links will be available on the task page in form of
# buttons.
- operator_extra_links: List[Any] = []
+ operator_extra_links: list[Any] = []
- ti_deps: List[Any] = []
+ ti_deps: list[Any] = []
# A list of timetable classes that can be used for DAG scheduling.
- timetables: List[Type["Timetable"]] = []
+ timetables: list[type[Timetable]] = []
- listeners: List[ModuleType] = []
+ listeners: list[ModuleType] = []
@classmethod
def validate(cls):
@@ -221,8 +223,8 @@ def load_entrypoint_plugins():
log.debug("Loading plugins from entrypoints")
- for entry_point, dist in entry_points_with_dist('airflow.plugins'):
- log.debug('Importing entry_point plugin %s', entry_point.name)
+ for entry_point, dist in entry_points_with_dist("airflow.plugins"):
+ log.debug("Importing entry_point plugin %s", entry_point.name)
try:
plugin_class = entry_point.load()
if not is_valid_plugin(plugin_class):
@@ -245,7 +247,7 @@ def load_plugins_from_plugin_directory():
if not os.path.isfile(file_path):
continue
mod_name, file_ext = os.path.splitext(os.path.split(file_path)[-1])
- if file_ext != '.py':
+ if file_ext != ".py":
continue
try:
@@ -254,25 +256,25 @@ def load_plugins_from_plugin_directory():
mod = importlib.util.module_from_spec(spec)
sys.modules[spec.name] = mod
loader.exec_module(mod)
- log.debug('Importing plugin module %s', file_path)
+ log.debug("Importing plugin module %s", file_path)
for mod_attr_value in (m for m in mod.__dict__.values() if is_valid_plugin(m)):
plugin_instance = mod_attr_value()
plugin_instance.source = PluginsDirectorySource(file_path)
register_plugin(plugin_instance)
except Exception as e:
- log.exception('Failed to import plugin %s', file_path)
+ log.exception("Failed to import plugin %s", file_path)
import_errors[file_path] = str(e)
-def make_module(name: str, objects: List[Any]):
+def make_module(name: str, objects: list[Any]):
"""Creates new module."""
if not objects:
return None
- log.debug('Creating module %s', name)
+ log.debug("Creating module %s", name)
name = name.lower()
module = types.ModuleType(name)
- module._name = name.split('.')[-1] # type: ignore
+ module._name = name.split(".")[-1] # type: ignore
module._objects = objects # type: ignore
module.__dict__.update((o.__name__, o) for o in objects)
return module
@@ -342,13 +344,13 @@ def initialize_web_ui_plugins():
for plugin in plugins:
flask_appbuilder_views.extend(plugin.appbuilder_views)
flask_appbuilder_menu_links.extend(plugin.appbuilder_menu_items)
- flask_blueprints.extend([{'name': plugin.name, 'blueprint': bp} for bp in plugin.flask_blueprints])
+ flask_blueprints.extend([{"name": plugin.name, "blueprint": bp} for bp in plugin.flask_blueprints])
if (plugin.admin_views and not plugin.appbuilder_views) or (
plugin.menu_links and not plugin.appbuilder_menu_items
):
log.warning(
- "Plugin \'%s\' may not be compatible with the current Airflow version. "
+ "Plugin '%s' may not be compatible with the current Airflow version. "
"Please contact the author of the plugin.",
plugin.name,
)
@@ -450,7 +452,7 @@ def integrate_executor_plugins() -> None:
raise AirflowPluginException("Invalid plugin name")
plugin_name: str = plugin.name
- executors_module = make_module('airflow.executors.' + plugin_name, plugin.executors)
+ executors_module = make_module("airflow.executors." + plugin_name, plugin.executors)
if executors_module:
executors_modules.append(executors_module)
sys.modules[executors_module.__name__] = executors_module
@@ -479,7 +481,7 @@ def integrate_macros_plugins() -> None:
if plugin.name is None:
raise AirflowPluginException("Invalid plugin name")
- macros_module = make_module(f'airflow.macros.{plugin.name}', plugin.macros)
+ macros_module = make_module(f"airflow.macros.{plugin.name}", plugin.macros)
if macros_module:
macros_modules.append(macros_module)
@@ -489,7 +491,7 @@ def integrate_macros_plugins() -> None:
setattr(macros, plugin.name, macros_module)
-def integrate_listener_plugins(listener_manager: "ListenerManager") -> None:
+def integrate_listener_plugins(listener_manager: ListenerManager) -> None:
global plugins
ensure_plugins_loaded()
@@ -503,7 +505,7 @@ def integrate_listener_plugins(listener_manager: "ListenerManager") -> None:
listener_manager.add_listener(listener)
-def get_plugin_info(attrs_to_dump: Optional[Iterable[str]] = None) -> List[Dict[str, Any]]:
+def get_plugin_info(attrs_to_dump: Iterable[str] | None = None) -> list[dict[str, Any]]:
"""
Dump plugins attributes
@@ -519,23 +521,23 @@ def get_plugin_info(attrs_to_dump: Optional[Iterable[str]] = None) -> List[Dict[
plugins_info = []
if plugins:
for plugin in plugins:
- info: Dict[str, Any] = {"name": plugin.name}
+ info: dict[str, Any] = {"name": plugin.name}
for attr in attrs_to_dump:
- if attr in ('global_operator_extra_links', 'operator_extra_links'):
+ if attr in ("global_operator_extra_links", "operator_extra_links"):
info[attr] = [
- f'<{as_importable_string(d.__class__)} object>' for d in getattr(plugin, attr)
+ f"<{as_importable_string(d.__class__)} object>" for d in getattr(plugin, attr)
]
- elif attr in ('macros', 'timetables', 'hooks', 'executors'):
+ elif attr in ("macros", "timetables", "hooks", "executors"):
info[attr] = [as_importable_string(d) for d in getattr(plugin, attr)]
- elif attr == 'listeners':
+ elif attr == "listeners":
# listeners are always modules
info[attr] = [d.__name__ for d in getattr(plugin, attr)]
- elif attr == 'appbuilder_views':
+ elif attr == "appbuilder_views":
info[attr] = [
- {**d, 'view': as_importable_string(d['view'].__class__) if 'view' in d else None}
+ {**d, "view": as_importable_string(d["view"].__class__) if "view" in d else None}
for d in getattr(plugin, attr)
]
- elif attr == 'flask_blueprints':
+ elif attr == "flask_blueprints":
info[attr] = [
(
f"<{as_importable_string(d.__class__)}: "
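The ``make_module`` helper touched in the ``plugins_manager.py`` hunks above builds a synthetic module from a list of plugin objects, and ``integrate_executor_plugins`` then registers it in ``sys.modules``. A trimmed, standalone sketch of that mechanism is shown below. It omits the logging and the private ``_name``/``_objects`` attributes; ``DemoExecutor`` and the ``demo_plugins`` module path are hypothetical.

```python
import sys
import types


def make_module(name, objects):
    """Build a synthetic module exposing ``objects`` under their ``__name__``."""
    if not objects:
        return None
    name = name.lower()
    module = types.ModuleType(name)
    module.__dict__.update((o.__name__, o) for o in objects)
    return module


class DemoExecutor:
    """Hypothetical plugin-provided class, used only for this illustration."""


demo_module = make_module("demo_plugins.executors.my_plugin", [DemoExecutor])
# Registering the module makes it reachable by its dotted name via sys.modules.
sys.modules[demo_module.__name__] = demo_module
```

Returning ``None`` for an empty object list lets the caller skip registration entirely, which matches the ``if executors_module:`` guard in the diff.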
diff --git a/airflow/provider.yaml.schema.json b/airflow/provider.yaml.schema.json
index d34fcce95e7c2..ff0537db32cf4 100644
--- a/airflow/provider.yaml.schema.json
+++ b/airflow/provider.yaml.schema.json
@@ -21,8 +21,8 @@
"type": "string"
}
},
- "additional-dependencies": {
- "description": "Additional dependencies that should be added to the provider",
+ "dependencies": {
+ "description": "Dependencies that should be added to the provider",
"type": "array",
"items": {
"type": "string"
@@ -194,17 +194,6 @@
]
}
},
- "hook-class-names": {
- "type": "array",
- "description": "Hook class names that provide connection types to core (deprecated by connection-types)",
- "items": {
- "type": "string"
- },
- "deprecated": {
- "description": "The hook-class-names property has been deprecated in favour of connection-types which is more performant version allowing to only import individual Hooks rather than all hooks at once",
- "deprecatedVersion": "2.2"
- }
- },
"connection-types": {
"type": "array",
"description": "Array of connection types mapped to hook class names",
@@ -219,9 +208,12 @@
"description": "Hook class name that implements the connection type",
"type": "string"
}
- }
- },
- "required": ["connection-type", "hook-class-name"]
+ },
+ "required": [
+ "connection-type",
+ "hook-class-name"
+ ]
+ }
},
"extra-links": {
"type": "array",
@@ -231,8 +223,26 @@
}
},
"additional-extras": {
- "type": "object",
- "description": "Additional extras that the provider should have"
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "description": "Name of the extra",
+ "type": "string"
+ },
+ "dependencies": {
+ "description": "Dependencies that should be added for the extra",
+ "type": "array",
+ "items": {
+ "type": "string"
+ }
+ }
+ },
+ "required": [ "name", "dependencies"]
+ },
+
+ "description": "Additional extras that the provider should have. Replaces auto-generated cross-provider extras, if matching the same prefix, so that you can specify boundaries for existing dependencies."
},
"task-decorators": {
"type": "array",
@@ -273,6 +283,7 @@
"name",
"package-name",
"description",
+ "dependencies",
"versions"
]
}
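The schema hunks above change ``additional-extras`` from a free-form object to an array of objects, each requiring a ``name`` and a list of ``dependencies``. The sketch below hand-rolls a shape check that mirrors those requirements; it is not a JSON Schema validator, and ``check_additional_extras`` as well as the sample extra name and requirement string are hypothetical.

```python
def check_additional_extras(extras):
    """Return True if ``extras`` is a list of objects that each carry a
    string ``name`` and a list of string ``dependencies``."""
    if not isinstance(extras, list):
        return False
    for extra in extras:
        if not isinstance(extra, dict):
            return False
        if not isinstance(extra.get("name"), str):
            return False
        deps = extra.get("dependencies")
        if not isinstance(deps, list) or not all(isinstance(d, str) for d in deps):
            return False
    return True


# Hypothetical entry in the new array-of-objects shape.
new_style = [{"name": "example.extra", "dependencies": ["example-package>=1.0"]}]
# The previous schema only required an object, so a mapping like this was accepted.
old_style = {"example.extra": ["example-package>=1.0"]}
```

Under these assumptions ``check_additional_extras(new_style)`` succeeds while ``old_style`` is rejected, which is the practical effect of the type change for existing ``provider.yaml`` files.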
diff --git a/airflow/providers/airbyte/.latest-doc-only-change.txt b/airflow/providers/airbyte/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/airbyte/.latest-doc-only-change.txt
+++ b/airflow/providers/airbyte/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/airbyte/CHANGELOG.rst b/airflow/providers/airbyte/CHANGELOG.rst
index deb0e79e968e0..103cf3179fc46 100644
--- a/airflow/providers/airbyte/CHANGELOG.rst
+++ b/airflow/providers/airbyte/CHANGELOG.rst
@@ -16,9 +16,61 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.2.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``AIP-47 - Migrate Airbyte DAGs to new design (#25135)``
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``'AirbyteHook' add cancel job option (#24593)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.1.4
.....
diff --git a/airflow/providers/airbyte/example_dags/example_airbyte_trigger_job.py b/airflow/providers/airbyte/example_dags/example_airbyte_trigger_job.py
deleted file mode 100644
index 55563ff5e03bd..0000000000000
--- a/airflow/providers/airbyte/example_dags/example_airbyte_trigger_job.py
+++ /dev/null
@@ -1,57 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Example DAG demonstrating the usage of the AirbyteTriggerSyncOperator."""
-
-from datetime import datetime, timedelta
-
-from airflow import DAG
-from airflow.providers.airbyte.operators.airbyte import AirbyteTriggerSyncOperator
-from airflow.providers.airbyte.sensors.airbyte import AirbyteJobSensor
-
-with DAG(
- dag_id='example_airbyte_operator',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- dagrun_timeout=timedelta(minutes=60),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_airbyte_synchronous]
- sync_source_destination = AirbyteTriggerSyncOperator(
- task_id='airbyte_sync_source_dest_example',
- connection_id='15bc3800-82e4-48c3-a32d-620661273f28',
- )
- # [END howto_operator_airbyte_synchronous]
-
- # [START howto_operator_airbyte_asynchronous]
- async_source_destination = AirbyteTriggerSyncOperator(
- task_id='airbyte_async_source_dest_example',
- connection_id='15bc3800-82e4-48c3-a32d-620661273f28',
- asynchronous=True,
- )
-
- airbyte_sensor = AirbyteJobSensor(
- task_id='airbyte_sensor_source_dest_example',
- airbyte_job_id=async_source_destination.output,
- )
- # [END howto_operator_airbyte_asynchronous]
-
- # Task dependency created via `XComArgs`:
- # async_source_destination >> airbyte_sensor
diff --git a/airflow/providers/airbyte/hooks/airbyte.py b/airflow/providers/airbyte/hooks/airbyte.py
index b1f6317530514..0ececccdfa907 100644
--- a/airflow/providers/airbyte/hooks/airbyte.py
+++ b/airflow/providers/airbyte/hooks/airbyte.py
@@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import time
-from typing import Any, Optional, Union
+from typing import Any
from airflow.exceptions import AirflowException
from airflow.providers.http.hooks.http import HttpHook
@@ -31,10 +33,10 @@ class AirbyteHook(HttpHook):
:param api_version: Optional. Airbyte API version.
"""
- conn_name_attr = 'airbyte_conn_id'
- default_conn_name = 'airbyte_default'
- conn_type = 'airbyte'
- hook_name = 'Airbyte'
+ conn_name_attr = "airbyte_conn_id"
+ default_conn_name = "airbyte_default"
+ conn_type = "airbyte"
+ hook_name = "Airbyte"
RUNNING = "running"
SUCCEEDED = "succeeded"
@@ -48,9 +50,7 @@ def __init__(self, airbyte_conn_id: str = "airbyte_default", api_version: str =
super().__init__(http_conn_id=airbyte_conn_id)
self.api_version: str = api_version
- def wait_for_job(
- self, job_id: Union[str, int], wait_seconds: float = 3, timeout: Optional[float] = 3600
- ) -> None:
+ def wait_for_job(self, job_id: str | int, wait_seconds: float = 3, timeout: float | None = 3600) -> None:
"""
Helper method which polls a job to check if it finishes.
@@ -107,21 +107,33 @@ def get_job(self, job_id: int) -> Any:
headers={"accept": "application/json"},
)
+ def cancel_job(self, job_id: int) -> Any:
+ """
+ Cancel the job when task is cancelled
+
+ :param job_id: Required. Id of the Airbyte job
+ """
+ return self.run(
+ endpoint=f"api/{self.api_version}/jobs/cancel",
+ json={"id": job_id},
+ headers={"accept": "application/json"},
+ )
+
def test_connection(self):
"""Tests the Airbyte connection by hitting the health API"""
- self.method = 'GET'
+ self.method = "GET"
try:
res = self.run(
endpoint=f"api/{self.api_version}/health",
headers={"accept": "application/json"},
- extra_options={'check_response': False},
+ extra_options={"check_response": False},
)
if res.status_code == 200:
- return True, 'Connection successfully tested'
+ return True, "Connection successfully tested"
else:
return False, res.text
except Exception as e:
return False, str(e)
finally:
- self.method = 'POST'
+ self.method = "POST"
diff --git a/airflow/providers/airbyte/operators/airbyte.py b/airflow/providers/airbyte/operators/airbyte.py
index ef2e2c1559902..e5ebc443e6d2d 100644
--- a/airflow/providers/airbyte/operators/airbyte.py
+++ b/airflow/providers/airbyte/operators/airbyte.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
@@ -45,16 +47,16 @@ class AirbyteTriggerSyncOperator(BaseOperator):
Only used when ``asynchronous`` is False.
"""
- template_fields: Sequence[str] = ('connection_id',)
+ template_fields: Sequence[str] = ("connection_id",)
def __init__(
self,
connection_id: str,
airbyte_conn_id: str = "airbyte_default",
- asynchronous: Optional[bool] = False,
+ asynchronous: bool | None = False,
api_version: str = "v1",
wait_seconds: float = 3,
- timeout: Optional[float] = 3600,
+ timeout: float | None = 3600,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -65,16 +67,22 @@ def __init__(
self.wait_seconds = wait_seconds
self.asynchronous = asynchronous
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
"""Create Airbyte Job and wait to finish"""
- hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
- job_object = hook.submit_sync_connection(connection_id=self.connection_id)
- job_id = job_object.json()['job']['id']
+ self.hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
+ job_object = self.hook.submit_sync_connection(connection_id=self.connection_id)
+ self.job_id = job_object.json()["job"]["id"]
- self.log.info("Job %s was submitted to Airbyte Server", job_id)
+ self.log.info("Job %s was submitted to Airbyte Server", self.job_id)
if not self.asynchronous:
- self.log.info('Waiting for job %s to complete', job_id)
- hook.wait_for_job(job_id=job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
- self.log.info('Job %s completed successfully', job_id)
+ self.log.info("Waiting for job %s to complete", self.job_id)
+ self.hook.wait_for_job(job_id=self.job_id, wait_seconds=self.wait_seconds, timeout=self.timeout)
+ self.log.info("Job %s completed successfully", self.job_id)
+
+ return self.job_id
- return job_id
+ def on_kill(self):
+ """Cancel the job if task is cancelled"""
+ if self.job_id:
+ self.log.info("on_kill: cancel the airbyte Job %s", self.job_id)
+ self.hook.cancel_job(self.job_id)
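Note on the ``on_kill`` hunk above: the hook and the job id move from local variables to instance attributes so that ``on_kill`` can reach them after ``execute`` has started. The sketch below illustrates the idea without Airflow. ``FakeHook`` and ``SyncOperatorSketch`` are hypothetical names, and initialising ``job_id`` to ``None`` in ``__init__`` is an assumption of this sketch; the diff does not show whether the operator does so.

```python
from __future__ import annotations


class FakeHook:
    """Hypothetical stand-in for AirbyteHook; records cancelled job ids."""

    def __init__(self) -> None:
        self.cancelled = []

    def submit(self) -> int:
        return 42  # pretend the server assigned job id 42

    def cancel_job(self, job_id: int) -> None:
        self.cancelled.append(job_id)


class SyncOperatorSketch:
    def __init__(self) -> None:
        # Assumption: defaults are set here so on_kill is safe before execute.
        self.hook = None
        self.job_id = None

    def execute(self) -> int:
        # Keep the hook and job id on the instance so on_kill can use them.
        self.hook = FakeHook()
        self.job_id = self.hook.submit()
        return self.job_id

    def on_kill(self) -> None:
        # Cancel only if a job was actually submitted.
        if self.hook is not None and self.job_id is not None:
            self.hook.cancel_job(self.job_id)
```

With the defaults in place, calling ``on_kill`` before ``execute`` is a no-op instead of an ``AttributeError``.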
diff --git a/airflow/providers/airbyte/provider.yaml b/airflow/providers/airbyte/provider.yaml
index a292b5378b14e..2d164460e6cf9 100644
--- a/airflow/providers/airbyte/provider.yaml
+++ b/airflow/providers/airbyte/provider.yaml
@@ -22,6 +22,9 @@ description: |
`Airbyte `__
versions:
+ - 3.2.0
+ - 3.1.0
+ - 3.0.0
- 2.1.4
- 2.1.3
- 2.1.2
@@ -30,8 +33,9 @@ versions:
- 2.0.0
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-http
integrations:
- integration-name: Airbyte
@@ -56,9 +60,6 @@ sensors:
python-modules:
- airflow.providers.airbyte.sensors.airbyte
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.airbyte.hooks.airbyte.AirbyteHook
-
connection-types:
- hook-class-name: airflow.providers.airbyte.hooks.airbyte.AirbyteHook
connection-type: airbyte
diff --git a/airflow/providers/airbyte/sensors/airbyte.py b/airflow/providers/airbyte/sensors/airbyte.py
index 10c5954ee3a79..f1c03c90a35b1 100644
--- a/airflow/providers/airbyte/sensors/airbyte.py
+++ b/airflow/providers/airbyte/sensors/airbyte.py
@@ -16,6 +16,8 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains a Airbyte Job sensor."""
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
@@ -36,14 +38,14 @@ class AirbyteJobSensor(BaseSensorOperator):
:param api_version: Optional. Airbyte API version.
"""
- template_fields: Sequence[str] = ('airbyte_job_id',)
- ui_color = '#6C51FD'
+ template_fields: Sequence[str] = ("airbyte_job_id",)
+ ui_color = "#6C51FD"
def __init__(
self,
*,
airbyte_job_id: int,
- airbyte_conn_id: str = 'airbyte_default',
+ airbyte_conn_id: str = "airbyte_default",
api_version: str = "v1",
**kwargs,
) -> None:
@@ -52,10 +54,10 @@ def __init__(
self.airbyte_job_id = airbyte_job_id
self.api_version = api_version
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
hook = AirbyteHook(airbyte_conn_id=self.airbyte_conn_id, api_version=self.api_version)
job = hook.get_job(job_id=self.airbyte_job_id)
- status = job.json()['job']['status']
+ status = job.json()["job"]["status"]
if status == hook.FAILED:
raise AirflowException(f"Job failed: \n{job}")
diff --git a/airflow/providers/alibaba/CHANGELOG.rst b/airflow/providers/alibaba/CHANGELOG.rst
index 3eceea9636fe2..265b65f634550 100644
--- a/airflow/providers/alibaba/CHANGELOG.rst
+++ b/airflow/providers/alibaba/CHANGELOG.rst
@@ -16,9 +16,83 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+2.2.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+* ``Use log.exception where more economical than log.error (#27517)``
+* ``Replace urlparse with urlsplit (#27389)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+2.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Auto tail file logs in Web UI (#26169)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+2.0.1
+.....
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Update providers to use functools compat for ''cached_property'' (#24582)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Features
+~~~~~~~~
+
+ * ``SSL Bucket, Light Logic Refactor and Docstring Update for Alibaba Provider (#23891)``
+
+Misc
+~~~~
+
+ * ``Apply per-run log templates to log handlers (#24153)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Migrate Alibaba example DAGs to new design #22437 (#24130)``
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
1.1.1
.....
diff --git a/airflow/providers/alibaba/cloud/example_dags/example_oss_bucket.py b/airflow/providers/alibaba/cloud/example_dags/example_oss_bucket.py
deleted file mode 100644
index 6d16ce3feb23f..0000000000000
--- a/airflow/providers/alibaba/cloud/example_dags/example_oss_bucket.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-from datetime import datetime
-
-from airflow.models.dag import DAG
-from airflow.providers.alibaba.cloud.operators.oss import OSSCreateBucketOperator, OSSDeleteBucketOperator
-
-# [START howto_operator_oss_bucket]
-with DAG(
- dag_id='oss_bucket_dag',
- start_date=datetime(2021, 1, 1),
- default_args={'bucket_name': 'your bucket', 'region': 'your region'},
- max_active_runs=1,
- tags=['example'],
- catchup=False,
-) as dag:
-
- create_bucket = OSSCreateBucketOperator(task_id='task1')
-
- delete_bucket = OSSDeleteBucketOperator(task_id='task2')
-
- create_bucket >> delete_bucket
-# [END howto_operator_oss_bucket]
diff --git a/airflow/providers/alibaba/cloud/example_dags/example_oss_object.py b/airflow/providers/alibaba/cloud/example_dags/example_oss_object.py
deleted file mode 100644
index ac3a07674492f..0000000000000
--- a/airflow/providers/alibaba/cloud/example_dags/example_oss_object.py
+++ /dev/null
@@ -1,61 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-from datetime import datetime
-
-from airflow.models.dag import DAG
-from airflow.providers.alibaba.cloud.operators.oss import (
- OSSDeleteBatchObjectOperator,
- OSSDeleteObjectOperator,
- OSSDownloadObjectOperator,
- OSSUploadObjectOperator,
-)
-
-with DAG(
- dag_id='oss_object_dag',
- start_date=datetime(2021, 1, 1),
- default_args={'bucket_name': 'your bucket', 'region': 'your region'},
- max_active_runs=1,
- tags=['example'],
- catchup=False,
-) as dag:
-
- create_object = OSSUploadObjectOperator(
- file='your local file',
- key='your oss key',
- task_id='task1',
- )
-
- download_object = OSSDownloadObjectOperator(
- file='your local file',
- key='your oss key',
- task_id='task2',
- )
-
- delete_object = OSSDeleteObjectOperator(
- key='your oss key',
- task_id='task3',
- )
-
- delete_batch_object = OSSDeleteBatchObjectOperator(
- keys=['obj1', 'obj2', 'obj3'],
- task_id='task4',
- )
-
- create_object >> download_object >> delete_object >> delete_batch_object
diff --git a/airflow/providers/alibaba/cloud/hooks/oss.py b/airflow/providers/alibaba/cloud/hooks/oss.py
index 08272adb25e70..ab29958a00ba7 100644
--- a/airflow/providers/alibaba/cloud/hooks/oss.py
+++ b/airflow/providers/alibaba/cloud/hooks/oss.py
@@ -15,10 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from functools import wraps
from inspect import signature
-from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast
-from urllib.parse import urlparse
+from typing import TYPE_CHECKING, Callable, TypeVar, cast
+from urllib.parse import urlsplit
import oss2
from oss2.exceptions import ClientError
@@ -43,10 +45,10 @@ def provide_bucket_name(func: T) -> T:
def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
self = args[0]
- if bound_args.arguments.get('bucket_name') is None and self.oss_conn_id:
+ if bound_args.arguments.get("bucket_name") is None and self.oss_conn_id:
connection = self.get_connection(self.oss_conn_id)
if connection.schema:
- bound_args.arguments['bucket_name'] = connection.schema
+ bound_args.arguments["bucket_name"] = connection.schema
return func(*bound_args.args, **bound_args.kwargs)
@@ -65,13 +67,13 @@ def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
def get_key() -> str:
- if 'key' in bound_args.arguments:
- return 'key'
- raise ValueError('Missing key parameter!')
+ if "key" in bound_args.arguments:
+ return "key"
+ raise ValueError("Missing key parameter!")
key_name = get_key()
- if 'bucket_name' not in bound_args.arguments or bound_args.arguments['bucket_name'] is None:
- bound_args.arguments['bucket_name'], bound_args.arguments['key'] = OSSHook.parse_oss_url(
+ if "bucket_name" not in bound_args.arguments or bound_args.arguments["bucket_name"] is None:
+ bound_args.arguments["bucket_name"], bound_args.arguments["key"] = OSSHook.parse_oss_url(
bound_args.arguments[key_name]
)
@@ -83,18 +85,18 @@ def get_key() -> str:
class OSSHook(BaseHook):
"""Interact with Alibaba Cloud OSS, using the oss2 library."""
- conn_name_attr = 'alibabacloud_conn_id'
- default_conn_name = 'oss_default'
- conn_type = 'oss'
- hook_name = 'OSS'
+ conn_name_attr = "alibabacloud_conn_id"
+ default_conn_name = "oss_default"
+ conn_type = "oss"
+ hook_name = "OSS"
- def __init__(self, region: Optional[str] = None, oss_conn_id='oss_default', *args, **kwargs) -> None:
+ def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args, **kwargs) -> None:
self.oss_conn_id = oss_conn_id
self.oss_conn = self.get_connection(oss_conn_id)
self.region = self.get_default_region() if region is None else region
super().__init__(*args, **kwargs)
- def get_conn(self) -> "Connection":
+ def get_conn(self) -> Connection:
"""Returns connection for the hook."""
return self.oss_conn
@@ -106,26 +108,25 @@ def parse_oss_url(ossurl: str) -> tuple:
:param ossurl: The OSS Url to parse.
:return: the parsed bucket name and key
"""
- parsed_url = urlparse(ossurl)
+ parsed_url = urlsplit(ossurl)
if not parsed_url.netloc:
raise AirflowException(f'Please provide a bucket_name instead of "{ossurl}"')
bucket_name = parsed_url.netloc
- key = parsed_url.path.lstrip('/')
+ key = parsed_url.path.lstrip("/")
return bucket_name, key
@provide_bucket_name
@unify_bucket_name_and_key
- def object_exists(self, key: str, bucket_name: Optional[str] = None) -> bool:
+ def object_exists(self, key: str, bucket_name: str | None = None) -> bool:
"""
Check if object exists.
:param key: the path of the object
:param bucket_name: the name of the bucket
:return: True if it exists and False if not.
- :rtype: bool
"""
try:
return self.get_bucket(bucket_name).object_exists(key)
@@ -134,21 +135,20 @@ def object_exists(self, key: str, bucket_name: Optional[str] = None) -> bool:
return False
@provide_bucket_name
- def get_bucket(self, bucket_name: Optional[str] = None) -> oss2.api.Bucket:
+ def get_bucket(self, bucket_name: str | None = None) -> oss2.api.Bucket:
"""
Returns a oss2.Bucket object
:param bucket_name: the name of the bucket
:return: the bucket object to the bucket name.
- :rtype: oss2.api.Bucket
"""
auth = self.get_credential()
assert self.region is not None
- return oss2.Bucket(auth, f'https://oss-{self.region}.aliyuncs.com', bucket_name)
+ return oss2.Bucket(auth, f"https://oss-{self.region}.aliyuncs.com", bucket_name)
@provide_bucket_name
@unify_bucket_name_and_key
- def load_string(self, key: str, content: str, bucket_name: Optional[str] = None) -> None:
+ def load_string(self, key: str, content: str, bucket_name: str | None = None) -> None:
"""
Loads a string to OSS
@@ -167,7 +167,7 @@ def upload_local_file(
self,
key: str,
file: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
) -> None:
"""
Upload a local file to OSS
@@ -187,8 +187,8 @@ def download_file(
self,
key: str,
local_file: str,
- bucket_name: Optional[str] = None,
- ) -> Optional[str]:
+ bucket_name: str | None = None,
+ ) -> str | None:
"""
Download file from OSS
@@ -196,7 +196,6 @@ def download_file(
:param local_file: local path + file name to save.
:param bucket_name: the name of the bucket
:return: the file name.
- :rtype: str
"""
try:
self.get_bucket(bucket_name).get_object_to_file(key, local_file)
@@ -210,7 +209,7 @@ def download_file(
def delete_object(
self,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
) -> None:
"""
Delete object from OSS
@@ -229,7 +228,7 @@ def delete_object(
def delete_objects(
self,
key: list,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
) -> None:
"""
Delete objects from OSS
@@ -246,7 +245,7 @@ def delete_objects(
@provide_bucket_name
def delete_bucket(
self,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
) -> None:
"""
Delete bucket from OSS
@@ -262,7 +261,7 @@ def delete_bucket(
@provide_bucket_name
def create_bucket(
self,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
) -> None:
"""
Create bucket
@@ -277,7 +276,7 @@ def create_bucket(
@provide_bucket_name
@unify_bucket_name_and_key
- def append_string(self, bucket_name: Optional[str], content: str, key: str, pos: int) -> None:
+ def append_string(self, bucket_name: str | None, content: str, key: str, pos: int) -> None:
"""
Append string to a remote existing file
@@ -295,7 +294,7 @@ def append_string(self, bucket_name: Optional[str], content: str, key: str, pos:
@provide_bucket_name
@unify_bucket_name_and_key
- def read_key(self, bucket_name: Optional[str], key: str) -> str:
+ def read_key(self, bucket_name: str | None, key: str) -> str:
"""
Read oss remote object content with the specified key
@@ -311,7 +310,7 @@ def read_key(self, bucket_name: Optional[str], key: str) -> str:
@provide_bucket_name
@unify_bucket_name_and_key
- def head_key(self, bucket_name: Optional[str], key: str) -> oss2.models.HeadObjectResult:
+ def head_key(self, bucket_name: str | None, key: str) -> oss2.models.HeadObjectResult:
"""
Get meta info of the specified remote object
@@ -327,7 +326,7 @@ def head_key(self, bucket_name: Optional[str], key: str) -> oss2.models.HeadObje
@provide_bucket_name
@unify_bucket_name_and_key
- def key_exist(self, bucket_name: Optional[str], key: str) -> bool:
+ def key_exist(self, bucket_name: str | None, key: str) -> bool:
"""
Find out whether the specified key exists in the oss remote storage
@@ -335,7 +334,7 @@ def key_exist(self, bucket_name: Optional[str], key: str) -> bool:
:param key: oss bucket key
"""
# full_path = None
- self.log.info('Looking up oss bucket %s for bucket key %s ...', bucket_name, key)
+ self.log.info("Looking up oss bucket %s for bucket key %s ...", bucket_name, key)
try:
return self.get_bucket(bucket_name).object_exists(key)
except Exception as e:
@@ -344,14 +343,14 @@ def key_exist(self, bucket_name: Optional[str], key: str) -> bool:
def get_credential(self) -> oss2.auth.Auth:
extra_config = self.oss_conn.extra_dejson
- auth_type = extra_config.get('auth_type', None)
+ auth_type = extra_config.get("auth_type", None)
if not auth_type:
raise Exception("No auth_type specified in extra_config. ")
- if auth_type != 'AK':
+ if auth_type != "AK":
raise Exception(f"Unsupported auth_type: {auth_type}")
- oss_access_key_id = extra_config.get('access_key_id', None)
- oss_access_key_secret = extra_config.get('access_key_secret', None)
+ oss_access_key_id = extra_config.get("access_key_id", None)
+ oss_access_key_secret = extra_config.get("access_key_secret", None)
if not oss_access_key_id:
raise Exception(f"No access_key_id is specified for connection: {self.oss_conn_id}")
@@ -360,16 +359,16 @@ def get_credential(self) -> oss2.auth.Auth:
return oss2.Auth(oss_access_key_id, oss_access_key_secret)
- def get_default_region(self) -> Optional[str]:
+ def get_default_region(self) -> str | None:
extra_config = self.oss_conn.extra_dejson
- auth_type = extra_config.get('auth_type', None)
+ auth_type = extra_config.get("auth_type", None)
if not auth_type:
raise Exception("No auth_type specified in extra_config. ")
- if auth_type != 'AK':
+ if auth_type != "AK":
raise Exception(f"Unsupported auth_type: {auth_type}")
- default_region = extra_config.get('region', None)
+ default_region = extra_config.get("region", None)
if not default_region:
raise Exception(f"No region is specified for connection: {self.oss_conn_id}")
return default_region
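Note on the ``provide_bucket_name`` decorator touched in this file: it binds the call arguments against the wrapped function's signature and fills in ``bucket_name`` only when the caller left it unset. The following is a rough, stdlib-only sketch of that technique. ``HookSketch`` and its ``default_bucket`` attribute are hypothetical and replace the connection lookup (``connection.schema``) used by the real hook.

```python
from __future__ import annotations

from functools import wraps
from inspect import signature


def provide_bucket_name(func):
    """Fill in ``bucket_name`` from the instance when the caller omitted it."""
    function_signature = signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound_args = function_signature.bind(*args, **kwargs)
        self = args[0]
        # Hypothetical: the real hook reads the default from the connection schema.
        if bound_args.arguments.get("bucket_name") is None and self.default_bucket:
            bound_args.arguments["bucket_name"] = self.default_bucket
        return func(*bound_args.args, **bound_args.kwargs)

    return wrapper


class HookSketch:
    def __init__(self, default_bucket=None):
        self.default_bucket = default_bucket

    @provide_bucket_name
    def describe(self, key, bucket_name=None):
        return f"{bucket_name}/{key}"
```

An explicitly passed ``bucket_name`` always wins, since the default is applied only when the bound value is ``None``.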
diff --git a/airflow/providers/alibaba/cloud/log/oss_task_handler.py b/airflow/providers/alibaba/cloud/log/oss_task_handler.py
index bb1d801ac9821..c443b4e014392 100644
--- a/airflow/providers/alibaba/cloud/log/oss_task_handler.py
+++ b/airflow/providers/alibaba/cloud/log/oss_task_handler.py
@@ -15,16 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import contextlib
import os
import pathlib
-import sys
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from airflow.compat.functools import cached_property
from airflow.configuration import conf
from airflow.providers.alibaba.cloud.hooks.oss import OSSHook
from airflow.utils.log.file_task_handler import FileTaskHandler
@@ -38,27 +35,27 @@ class OSSTaskHandler(FileTaskHandler, LoggingMixin):
uploads to and reads from OSS remote storage.
"""
- def __init__(self, base_log_folder, oss_log_folder, filename_template):
+ def __init__(self, base_log_folder, oss_log_folder, filename_template=None):
self.log.info("Using oss_task_handler for remote logging...")
super().__init__(base_log_folder, filename_template)
(self.bucket_name, self.base_folder) = OSSHook.parse_oss_url(oss_log_folder)
- self.log_relative_path = ''
+ self.log_relative_path = ""
self._hook = None
self.closed = False
self.upload_on_close = True
@cached_property
def hook(self):
- remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
+ remote_conn_id = conf.get("logging", "REMOTE_LOG_CONN_ID")
self.log.info("remote_conn_id: %s", remote_conn_id)
try:
return OSSHook(oss_conn_id=remote_conn_id)
except Exception as e:
- self.log.error(e, exc_info=True)
+ self.log.exception(e)
self.log.error(
'Could not create an OSSHook with connection id "%s". '
- 'Please make sure that airflow[oss] is installed and '
- 'the OSS connection exists.',
+ "Please make sure that airflow[oss] is installed and "
+ "the OSS connection exists.",
remote_conn_id,
)
@@ -73,7 +70,7 @@ def set_context(self, ti):
# Clear the file first so that duplicate data is not uploaded
# when re-using the same path (e.g. with rescheduled sensors)
if self.upload_on_close:
- with open(self.handler.baseFilename, 'w'):
+ with open(self.handler.baseFilename, "w"):
pass
def close(self):
@@ -117,13 +114,13 @@ def _read(self, ti, try_number, metadata=None):
remote_loc = log_relative_path
if not self.oss_log_exists(remote_loc):
- return super()._read(ti, try_number)
+ return super()._read(ti, try_number, metadata)
# If OSS remote file exists, we do not fetch logs from task instance
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.oss_read(remote_loc, return_error=True)
- log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
- return log, {'end_of_log': True}
+ log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n"
+ return log, {"end_of_log": True}
def oss_log_exists(self, remote_log_location):
"""
@@ -132,7 +129,7 @@ def oss_log_exists(self, remote_log_location):
:param remote_log_location: log's location in remote storage
:return: True if location exists else False
"""
- oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
+ oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
with contextlib.suppress(Exception):
return self.hook.key_exist(self.bucket_name, oss_remote_log_location)
return False
@@ -147,11 +144,11 @@ def oss_read(self, remote_log_location, return_error=False):
error occurs. Otherwise returns '' when an error occurs.
"""
try:
- oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
+ oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
self.log.info("read remote log: %s", oss_remote_log_location)
return self.hook.read_key(self.bucket_name, oss_remote_log_location)
except Exception:
- msg = f'Could not read logs from {oss_remote_log_location}'
+ msg = f"Could not read logs from {oss_remote_log_location}"
self.log.exception(msg)
# return error if needed
if return_error:
@@ -167,7 +164,7 @@ def oss_write(self, log, remote_log_location, append=True):
:param append: if False, any existing log file is overwritten. If True,
the new log is appended to any existing logs.
"""
- oss_remote_log_location = f'{self.base_folder}/{remote_log_location}'
+ oss_remote_log_location = f"{self.base_folder}/{remote_log_location}"
pos = 0
if append and self.oss_log_exists(oss_remote_log_location):
head = self.hook.head_key(self.bucket_name, oss_remote_log_location)
@@ -178,7 +175,7 @@ def oss_write(self, log, remote_log_location, append=True):
self.hook.append_string(self.bucket_name, log, oss_remote_log_location, pos)
except Exception:
self.log.exception(
- 'Could not write logs to %s, log write pos is: %s, Append is %s',
+ "Could not write logs to %s, log write pos is: %s, Append is %s",
oss_remote_log_location,
str(pos),
str(append),
diff --git a/airflow/providers/alibaba/cloud/operators/oss.py b/airflow/providers/alibaba/cloud/operators/oss.py
index 8ec9b4b13975e..2e43529566b03 100644
--- a/airflow/providers/alibaba/cloud/operators/oss.py
+++ b/airflow/providers/alibaba/cloud/operators/oss.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains Alibaba Cloud OSS operators."""
-from typing import TYPE_CHECKING, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
from airflow.models import BaseOperator
from airflow.providers.alibaba.cloud.hooks.oss import OSSHook
@@ -38,8 +39,8 @@ class OSSCreateBucketOperator(BaseOperator):
def __init__(
self,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: str = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str = "oss_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -47,7 +48,7 @@ def __init__(
self.region = region
self.bucket_name = bucket_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region)
oss_hook.create_bucket(bucket_name=self.bucket_name)
@@ -64,8 +65,8 @@ class OSSDeleteBucketOperator(BaseOperator):
def __init__(
self,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: str = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str = "oss_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -73,7 +74,7 @@ def __init__(
self.region = region
self.bucket_name = bucket_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region)
oss_hook.delete_bucket(bucket_name=self.bucket_name)
@@ -94,8 +95,8 @@ def __init__(
key: str,
file: str,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: str = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str = "oss_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -105,7 +106,7 @@ def __init__(
self.region = region
self.bucket_name = bucket_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region)
oss_hook.upload_local_file(bucket_name=self.bucket_name, key=self.key, file=self.file)
@@ -126,8 +127,8 @@ def __init__(
key: str,
file: str,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: str = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str = "oss_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -137,7 +138,7 @@ def __init__(
self.region = region
self.bucket_name = bucket_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region)
oss_hook.download_file(bucket_name=self.bucket_name, key=self.key, local_file=self.file)
@@ -156,8 +157,8 @@ def __init__(
self,
keys: list,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: str = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str = "oss_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -166,7 +167,7 @@ def __init__(
self.region = region
self.bucket_name = bucket_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region)
oss_hook.delete_objects(bucket_name=self.bucket_name, key=self.keys)
@@ -185,8 +186,8 @@ def __init__(
self,
key: str,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: str = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str = "oss_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -195,6 +196,6 @@ def __init__(
self.region = region
self.bucket_name = bucket_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
oss_hook = OSSHook(oss_conn_id=self.oss_conn_id, region=self.region)
oss_hook.delete_object(bucket_name=self.bucket_name, key=self.key)
diff --git a/airflow/providers/alibaba/cloud/sensors/oss_key.py b/airflow/providers/alibaba/cloud/sensors/oss_key.py
index 0160783178b08..98b2c25beaee1 100644
--- a/airflow/providers/alibaba/cloud/sensors/oss_key.py
+++ b/airflow/providers/alibaba/cloud/sensors/oss_key.py
@@ -15,16 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
+from __future__ import annotations
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
-from typing import TYPE_CHECKING, Optional, Sequence
-from urllib.parse import urlparse
+from typing import TYPE_CHECKING, Sequence
+from urllib.parse import urlsplit
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.alibaba.cloud.hooks.oss import OSSHook
from airflow.sensors.base import BaseSensorOperator
@@ -47,14 +43,14 @@ class OSSKeySensor(BaseSensorOperator):
:param oss_conn_id: The Airflow connection used for OSS credentials.
"""
- template_fields: Sequence[str] = ('bucket_key', 'bucket_name')
+ template_fields: Sequence[str] = ("bucket_key", "bucket_name")
def __init__(
self,
bucket_key: str,
region: str,
- bucket_name: Optional[str] = None,
- oss_conn_id: Optional[str] = 'oss_default',
+ bucket_name: str | None = None,
+ oss_conn_id: str | None = "oss_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -63,9 +59,9 @@ def __init__(
self.bucket_key = bucket_key
self.region = region
self.oss_conn_id = oss_conn_id
- self.hook: Optional[OSSHook] = None
+ self.hook: OSSHook | None = None
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
"""
Check if the object exists in the bucket to pull key.
@param self - the object itself
@@ -73,21 +69,21 @@ def poke(self, context: 'Context'):
@returns True if the object exists, False otherwise
"""
if self.bucket_name is None:
- parsed_url = urlparse(self.bucket_key)
- if parsed_url.netloc == '':
- raise AirflowException('If key is a relative path from root, please provide a bucket_name')
+ parsed_url = urlsplit(self.bucket_key)
+ if parsed_url.netloc == "":
+ raise AirflowException("If key is a relative path from root, please provide a bucket_name")
self.bucket_name = parsed_url.netloc
- self.bucket_key = parsed_url.path.lstrip('/')
+ self.bucket_key = parsed_url.path.lstrip("/")
else:
- parsed_url = urlparse(self.bucket_key)
- if parsed_url.scheme != '' or parsed_url.netloc != '':
+ parsed_url = urlsplit(self.bucket_key)
+ if parsed_url.scheme != "" or parsed_url.netloc != "":
raise AirflowException(
- 'If bucket_name is provided, bucket_key'
- ' should be relative path from root'
- ' level, rather than a full oss:// url'
+ "If bucket_name is provided, bucket_key"
+ " should be relative path from root"
+ " level, rather than a full oss:// url"
)
- self.log.info('Poking for key : oss://%s/%s', self.bucket_name, self.bucket_key)
+ self.log.info("Poking for key : oss://%s/%s", self.bucket_name, self.bucket_key)
return self.get_hook.object_exists(key=self.bucket_key, bucket_name=self.bucket_name)
@cached_property
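The bucket/key resolution that ``poke`` performs in the hunk above can be sketched on its own with only the standard library. This is a minimal illustration, not the provider's code: the helper name ``resolve_bucket_and_key`` is hypothetical, and ``ValueError`` stands in for ``AirflowException`` so the snippet runs without Airflow installed.

```python
from __future__ import annotations

from urllib.parse import urlsplit


def resolve_bucket_and_key(bucket_key: str, bucket_name: str | None = None) -> tuple[str, str]:
    """Return (bucket_name, bucket_key), mirroring the two branches in ``poke``.

    Hypothetical helper for illustration; ValueError replaces AirflowException.
    """
    parsed_url = urlsplit(bucket_key)
    if bucket_name is None:
        # No bucket given: the key must be a full URL such as oss://bucket/path.
        if parsed_url.netloc == "":
            raise ValueError("If key is a relative path from root, please provide a bucket_name")
        return parsed_url.netloc, parsed_url.path.lstrip("/")
    # Bucket given: the key must be a plain relative path, not a URL.
    if parsed_url.scheme != "" or parsed_url.netloc != "":
        raise ValueError(
            "If bucket_name is provided, bucket_key should be relative path "
            "from root level, rather than a full oss:// url"
        )
    return bucket_name, bucket_key


print(resolve_bucket_and_key("oss://my-bucket/data/file.csv"))
print(resolve_bucket_and_key("data/file.csv", bucket_name="my-bucket"))
```

Both calls resolve to the same bucket and key. ``urlsplit`` is sufficient here because only ``scheme``, ``netloc`` and ``path`` are read; the ``params`` component that ``urlparse`` additionally splits out is never used.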
diff --git a/airflow/providers/alibaba/provider.yaml b/airflow/providers/alibaba/provider.yaml
index ca5e03f1b27d2..cd5cbced35249 100644
--- a/airflow/providers/alibaba/provider.yaml
+++ b/airflow/providers/alibaba/provider.yaml
@@ -22,13 +22,18 @@ description: |
Alibaba Cloud integration (including `Alibaba Cloud `__).
versions:
+ - 2.2.0
+ - 2.1.0
+ - 2.0.1
+ - 2.0.0
- 1.1.1
- 1.1.0
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - oss2>=2.14.0
integrations:
- integration-name: Alibaba Cloud OSS
@@ -53,8 +58,6 @@ hooks:
python-modules:
- airflow.providers.alibaba.cloud.hooks.oss
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.alibaba.cloud.hooks.oss.OSSHook
connection-types:
- hook-class-name: airflow.providers.alibaba.cloud.hooks.oss.OSSHook
diff --git a/airflow/providers/amazon/CHANGELOG.rst b/airflow/providers/amazon/CHANGELOG.rst
index e8a0f442c457c..d55861f1c6d2d 100644
--- a/airflow/providers/amazon/CHANGELOG.rst
+++ b/airflow/providers/amazon/CHANGELOG.rst
@@ -16,9 +16,300 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please only add notes to the Changelog just below the "Changelog" header when there are breaking changes
+ and you want to explain to users how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by the release manager.
+
Changelog
---------
+6.1.0
+.....
+
+This release of the provider is only available for Airflow 2.3+, as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+* ``Replace urlparse with urlsplit (#27389)``
+
+Features
+~~~~~~~~
+
+* ``Add info about JSON Connection format for AWS SSM Parameter Store Secrets Backend (#27134)``
+* ``Add default name to EMR Serverless jobs (#27458)``
+* ``Adding 'preserve_file_name' param to 'S3Hook.download_file' method (#26886)``
+* ``Add GlacierUploadArchiveOperator (#26652)``
+* ``Add RdsStopDbOperator and RdsStartDbOperator (#27076)``
+* ``'GoogleApiToS3Operator' : add 'gcp_conn_id' to template fields (#27017)``
+* ``Add SQLExecuteQueryOperator (#25717)``
+* ``Add information about Amazon Elastic MapReduce Connection (#26687)``
+* ``Add BatchOperator template fields (#26805)``
+* ``Improve testing AWS Connection response (#26953)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``SagemakerProcessingOperator stopped honoring 'existing_jobs_found' (#27456)``
+* ``CloudWatch task handler doesn't fall back to local logs when Amazon CloudWatch logs aren't found (#27564)``
+* ``Fix backwards compatibility for RedshiftSQLOperator (#27602)``
+* ``Fix typo in redshift sql hook get_ui_field_behaviour (#27533)``
+* ``Fix example_emr_serverless system test (#27149)``
+* ``Fix param in docstring RedshiftSQLHook get_table_primary_key method (#27330)``
+* ``Adds s3_key_prefix to template fields (#27207)``
+* ``Fix assume role if user explicit set credentials (#26946)``
+* ``Fix failure state in waiter call for EmrServerlessStartJobOperator. (#26853)``
+* ``Fix a bunch of deprecation warnings AWS tests (#26857)``
+* ``Fix null strings bug in SqlToS3Operator in non parquet formats (#26676)``
+* ``Sagemaker hook: remove extra call at the end when waiting for completion (#27551)``
+* ``ECS Buglette (#26921)``
+* ``Avoid circular imports in AWS Secrets Backends if obtain secrets from config (#26784)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``sagemaker operators: mutualize init of aws_conn_id (#27579)``
+ * ``Upgrade dependencies in order to avoid backtracking (#27531)``
+ * ``Code quality improvements on sagemaker operators/hook (#27453)``
+ * ``Update old style typing (#26872)``
+ * ``System test for SQL to S3 Transfer (AIP-47) (#27097)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Convert emr_eks example dag to system test (#26723)``
+ * ``System test for Dynamo DB (#26729)``
+ * ``ECS System Test (#26808)``
+ * ``RDS Instance System Tests (#26733)``
+
+6.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+.. warning::
+ In this version of the provider, the Amazon S3 Connection (``conn_type="s3"``) has been removed because it was always
+ an alias to the AWS connection (``conn_type="aws"``).
+ In practice, the only impact is that you won't be able to ``test`` the connection in the web UI / API.
+ To restore the ability to test the connection, manually change the connection type from **Amazon S3** (``conn_type="s3"``)
+ to **Amazon Web Services** (``conn_type="aws"``).
+
+* ``Remove Amazon S3 Connection Type (#25980)``
+
+Features
+~~~~~~~~
+
+* ``Add RdsDbSensor to amazon provider package (#26003)``
+* ``Set template_fields on RDS operators (#26005)``
+* ``Auto tail file logs in Web UI (#26169)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix SageMakerEndpointConfigOperator's return value (#26541)``
+* ``EMR Serverless Fix for Jobs marked as success even on failure (#26218)``
+* ``Fix AWS Connection warn condition for invalid 'profile_name' argument (#26464)``
+* ``Athena and EMR operator max_retries mix-up fix (#25971)``
+* ``Fixes SageMaker operator return values (#23628)``
+* ``Remove redundant catch exception in Amazon Log Task Handlers (#26442)``
+
+Misc
+~~~~
+
+* ``Remove duplicated connection-type within the provider (#26628)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Redshift to S3 and S3 to Redshift System test (AIP-47) (#26613)``
+ * ``Convert example_eks_with_fargate_in_one_step.py and example_eks_with_fargate_profile to AIP-47 (#26537)``
+ * ``Redshift System Test (AIP-47) (#26187)``
+ * ``GoogleAPIToS3Operator System Test (AIP-47) (#26370)``
+ * ``Convert EKS with Nodegroups sample DAG to a system test (AIP-47) (#26539)``
+ * ``Convert EC2 sample DAG to system test (#26540)``
+ * ``Convert S3 example DAG to System test (AIP-47) (#26535)``
+ * ``Convert 'example_eks_with_nodegroup_in_one_step' sample DAG to system test (AIP-47) (#26410)``
+ * ``Migrate DMS sample dag to system test (#26270)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``D400 first line should end with period batch02 (#25268)``
+ * ``Change links to 'boto3' documentation (#26708)``
+
+5.1.0
+.....
+
+
+Features
+~~~~~~~~
+
+* ``Additional mask aws credentials (#26014)``
+* ``Add RedshiftDeleteClusterSnapshotOperator (#25975)``
+* ``Add redshift create cluster snapshot operator (#25857)``
+* ``Add common-sql lower bound for common-sql (#25789)``
+* ``Allow AWS Secrets Backends use AWS Connection capabilities (#25628)``
+* ``Implement 'EmrEksCreateClusterOperator' (#25816)``
+* ``Improve error handling/messaging around bucket exist check (#25805)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix display aws connection info (#26025)``
+* ``Fix 'EcsBaseOperator' and 'EcsBaseSensor' arguments (#25989)``
+* ``Fix RDS system test (#25839)``
+* ``Avoid circular import problems when instantiating AWS SM backend (#25810)``
+* ``fix bug construction of Connection object in version 5.0.0rc3 (#25716)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Fix EMR serverless system test (#25969)``
+ * ``Add 'output' property to MappedOperator (#25604)``
+ * ``Add Airflow specific warning classes (#25799)``
+ * ``Replace SQL with Common SQL in pre commit (#26058)``
+ * ``Hook into Mypy to get rid of those cast() (#26023)``
+ * ``Raise an error on create bucket if use regional endpoint for us-east-1 and region not set (#25945)``
+ * ``Update AWS system tests to use SystemTestContextBuilder (#25748)``
+ * ``Convert Quicksight Sample DAG to System Test (#25696)``
+ * ``Consolidate to one 'schedule' param (#25410)``
+
+5.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* ``Avoid requirement that AWS Secret Manager JSON values be urlencoded. (#25432)``
+* ``Remove deprecated modules (#25543)``
+* ``Resolve Amazon Hook's 'region_name' and 'config' in wrapper (#25336)``
+* ``Resolve and validate AWS Connection parameters in wrapper (#25256)``
+* ``Standardize AwsLambda (#25100)``
+* ``Refactor monolithic ECS Operator into Operators, Sensors, and a Hook (#25413)``
+* ``Remove deprecated modules from Amazon provider package (#25609)``
+
+Features
+~~~~~~~~
+
+* ``Add EMR Serverless Operators and Hooks (#25324)``
+* ``Hide unused fields for Amazon Web Services connection (#25416)``
+* ``Enable Auto-incrementing Transform job name in SageMakerTransformOperator (#25263)``
+* ``Unify DbApiHook.run() method with the methods which override it (#23971)``
+* ``SQSPublishOperator should allow sending messages to a FIFO Queue (#25171)``
+* ``Glue Job Driver logging (#25142)``
+* ``Bump typing-extensions and mypy for ParamSpec (#25088)``
+* ``Enable multiple query execution in RedshiftDataOperator (#25619)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix S3Hook transfer config arguments validation (#25544)``
+* ``Fix BatchOperator links on wait_for_completion = True (#25228)``
+* ``Makes changes to SqlToS3Operator method _fix_int_dtypes (#25083)``
+* ``refactor: Deprecate parameter 'host' as an extra attribute for the connection. Depreciation is happening in favor of 'endpoint_url' in extra. (#25494)``
+* ``Get boto3.session.Session by appropriate method (#25569)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``System test for EMR Serverless (#25559)``
+ * ``Convert Local to S3 example DAG to System Test (AIP-47) (#25345)``
+ * ``Convert ECS Fargate Sample DAG to System Test (#25316)``
+ * ``Sagemaker System Tests - Part 3 of 3 - example_sagemaker_endpoint.py (AIP-47) (#25134)``
+ * ``Convert RDS Export Sample DAG to System Test (AIP-47) (#25205)``
+ * ``AIP-47 - Migrate redshift DAGs to new design #22438 (#24239)``
+ * ``Convert Glue Sample DAG to System Test (#25136)``
+ * ``Convert the batch sample dag to system tests (AIP-47) (#24448)``
+ * ``Migrate datasync sample dag to system tests (AIP-47) (#24354)``
+ * ``Sagemaker System Tests - Part 2 of 3 - example_sagemaker.py (#25079)``
+ * ``Migrate lambda sample dag to system test (AIP-47) (#24355)``
+ * ``SageMaker system tests - Part 1 of 3 - Prep Work (AIP-47) (#25078)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+
+4.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Add test_connection method to AWS hook (#24662)``
+* ``Add AWS operators to create and delete RDS Database (#24099)``
+* ``Add batch option to 'SqsSensor' (#24554)``
+* ``Add AWS Batch & AWS CloudWatch Extra Links (#24406)``
+* ``Refactoring EmrClusterLink and add for other AWS EMR Operators (#24294)``
+* ``Move all SQL classes to common-sql provider (#24836)``
+* ``Amazon appflow (#24057)``
+* ``Make extra_args in S3Hook immutable between calls (#24527)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Refactor and fix AWS secret manager invalid exception (#24898)``
+* ``fix: RedshiftDataHook and RdsHook not use cached connection (#24387)``
+* ``Fix links to sources for examples (#24386)``
+* ``Fix S3KeySensor. See #24321 (#24378)``
+* ``Fix: 'emr_conn_id' should be optional in 'EmrCreateJobFlowOperator' (#24306)``
+* ``Update providers to use functools compat for ''cached_property'' (#24582)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Convert RDS Event and Snapshot Sample DAGs to System Tests (#24932)``
+ * ``Convert Step Functions Example DAG to System Test (AIP-47) (#24643)``
+ * ``Update AWS Connection docs and deprecate some extras (#24670)``
+ * ``Remove 'xcom_push' flag from providers (#24823)``
+ * ``Align Black and blacken-docs configs (#24785)``
+ * ``Restore Optional value of script_location (#24754)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Use our yaml util in all providers (#24720)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+ * ``Convert SQS Sample DAG to System Test (#24513)``
+ * ``Convert Cloudformation Sample DAG to System Test (#24447)``
+ * ``Convert SNS Sample DAG to System Test (#24384)``
+
+4.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of the provider is only available for Airflow 2.2+, as explained in the Apache Airflow
+ providers support policy: https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Features
+~~~~~~~~
+
+* ``Add partition related methods to GlueCatalogHook: (#23857)``
+* ``Add support for associating custom tags to job runs submitted via EmrContainerOperator (#23769)``
+* ``Add number of node params only for single-node cluster in RedshiftCreateClusterOperator (#23839)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``fix: StepFunctionHook ignores explicit set 'region_name' (#23976)``
+* ``Fix Amazon EKS example DAG raises warning during Imports (#23849)``
+* ``Move string arg evals to 'execute()' in 'EksCreateClusterOperator' (#23877)``
+* ``fix: patches #24215. Won't raise KeyError when 'create_job_kwargs' contains the 'Command' key. (#24308)``
+
+Misc
+~~~~
+
+* ``Light Refactor and Clean-up AWS Provider (#23907)``
+* ``Update sample dag and doc for RDS (#23651)``
+* ``Reformat the whole AWS documentation (#23810)``
+* ``Replace "absolute()" with "resolve()" in pathlib objects (#23675)``
+* ``Apply per-run log templates to log handlers (#24153)``
+* ``Refactor GlueJobHook get_or_create_glue_job method. (#24215)``
+* ``Update the DMS Sample DAG and Docs (#23681)``
+* ``Update doc and sample dag for Quicksight (#23653)``
+* ``Update doc and sample dag for EMR Containers (#24087)``
+* ``Add AWS project structure tests (re: AIP-47) (#23630)``
+* ``Add doc and sample dag for GCSToS3Operator (#23730)``
+* ``Remove old Athena Sample DAG (#24170)``
+* ``Clean up f-strings in logging calls (#23597)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Introduce 'flake8-implicit-str-concat' plugin to static checks (#23873)``
+ * ``pydocstyle D202 added (#24221)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
3.4.0
.....
@@ -66,7 +357,6 @@ Misc
* ``Bump pre-commit hook versions (#22887)``
* ``Use new Breese for building, pulling and verifying the images. (#23104)``
-.. Review and move the new changes to one of the sections above:
3.3.0
.....
diff --git a/airflow/providers/amazon/aws/example_dags/example_appflow.py b/airflow/providers/amazon/aws/example_dags/example_appflow.py
new file mode 100644
index 0000000000000..c986961499dc8
--- /dev/null
+++ b/airflow/providers/amazon/aws/example_dags/example_appflow.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow import DAG
+from airflow.models.baseoperator import chain
+from airflow.operators.bash import BashOperator
+from airflow.providers.amazon.aws.operators.appflow import (
+ AppflowRecordsShortCircuitOperator,
+ AppflowRunAfterOperator,
+ AppflowRunBeforeOperator,
+ AppflowRunDailyOperator,
+ AppflowRunFullOperator,
+ AppflowRunOperator,
+)
+
+SOURCE_NAME = "salesforce"
+FLOW_NAME = "salesforce-campaign"
+
+with DAG(
+ "example_appflow",
+ start_date=datetime(2022, 1, 1),
+ catchup=False,
+ tags=["example"],
+) as dag:
+
+ # [START howto_operator_appflow_run]
+ campaign_dump = AppflowRunOperator(
+ task_id="campaign_dump",
+ source=SOURCE_NAME,
+ flow_name=FLOW_NAME,
+ )
+ # [END howto_operator_appflow_run]
+
+ # [START howto_operator_appflow_run_full]
+ campaign_dump_full = AppflowRunFullOperator(
+ task_id="campaign_dump_full",
+ source=SOURCE_NAME,
+ flow_name=FLOW_NAME,
+ )
+ # [END howto_operator_appflow_run_full]
+
+ # [START howto_operator_appflow_run_daily]
+ campaign_dump_daily = AppflowRunDailyOperator(
+ task_id="campaign_dump_daily",
+ source=SOURCE_NAME,
+ flow_name=FLOW_NAME,
+ source_field="LastModifiedDate",
+ filter_date="{{ ds }}",
+ )
+ # [END howto_operator_appflow_run_daily]
+
+ # [START howto_operator_appflow_run_before]
+ campaign_dump_before = AppflowRunBeforeOperator(
+ task_id="campaign_dump_before",
+ source=SOURCE_NAME,
+ flow_name=FLOW_NAME,
+ source_field="LastModifiedDate",
+ filter_date="{{ ds }}",
+ )
+ # [END howto_operator_appflow_run_before]
+
+ # [START howto_operator_appflow_run_after]
+ campaign_dump_after = AppflowRunAfterOperator(
+ task_id="campaign_dump_after",
+ source=SOURCE_NAME,
+ flow_name=FLOW_NAME,
+ source_field="LastModifiedDate",
+ filter_date="3000-01-01", # Future date, so no records to dump
+ )
+ # [END howto_operator_appflow_run_after]
+
+ # [START howto_operator_appflow_shortcircuit]
+ campaign_dump_short_circuit = AppflowRecordsShortCircuitOperator(
+ task_id="campaign_dump_short_circuit",
+ flow_name=FLOW_NAME,
+ appflow_run_task_id="campaign_dump_after", # Should shortcircuit, no records expected
+ )
+ # [END howto_operator_appflow_shortcircuit]
+
+ should_be_skipped = BashOperator(
+ task_id="should_be_skipped",
+ bash_command="echo 1",
+ )
+
+ chain(
+ campaign_dump,
+ campaign_dump_full,
+ campaign_dump_daily,
+ campaign_dump_before,
+ campaign_dump_after,
+ campaign_dump_short_circuit,
+ should_be_skipped,
+ )
diff --git a/airflow/providers/amazon/aws/example_dags/example_athena.py b/airflow/providers/amazon/aws/example_dags/example_athena.py
deleted file mode 100644
index 26b47bdfa19b5..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_athena.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.amazon.aws.operators.athena import AthenaOperator
-from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator, S3DeleteObjectsOperator
-from airflow.providers.amazon.aws.sensors.athena import AthenaSensor
-
-S3_BUCKET = getenv("S3_BUCKET", "test-bucket")
-S3_KEY = getenv('S3_KEY', 'athena-demo')
-ATHENA_TABLE = getenv('ATHENA_TABLE', 'test_table')
-ATHENA_DATABASE = getenv('ATHENA_DATABASE', 'default')
-
-SAMPLE_DATA = """"Alice",20
-"Bob",25
-"Charlie",30
-"""
-SAMPLE_FILENAME = 'airflow_sample.csv'
-
-
-@task
-def read_results_from_s3(query_execution_id):
- s3_hook = S3Hook()
- file_obj = s3_hook.get_conn().get_object(Bucket=S3_BUCKET, Key=f'{S3_KEY}/{query_execution_id}.csv')
- file_content = file_obj['Body'].read().decode('utf-8')
- print(file_content)
-
-
-QUERY_CREATE_TABLE = f"""
-CREATE EXTERNAL TABLE IF NOT EXISTS {ATHENA_DATABASE}.{ATHENA_TABLE} ( `name` string, `age` int )
-ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
-WITH SERDEPROPERTIES ( 'serialization.format' = ',', 'field.delim' = ','
-) LOCATION 's3://{S3_BUCKET}/{S3_KEY}/{ATHENA_TABLE}'
-TBLPROPERTIES ('has_encrypted_data'='false')
-"""
-
-QUERY_READ_TABLE = f"""
-SELECT * from {ATHENA_DATABASE}.{ATHENA_TABLE}
-"""
-
-QUERY_DROP_TABLE = f"""
-DROP TABLE IF EXISTS {ATHENA_DATABASE}.{ATHENA_TABLE}
-"""
-
-with DAG(
- dag_id='example_athena',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- upload_sample_data = S3CreateObjectOperator(
- task_id='upload_sample_data',
- s3_bucket=S3_BUCKET,
- s3_key=f'{S3_KEY}/{ATHENA_TABLE}/{SAMPLE_FILENAME}',
- data=SAMPLE_DATA,
- replace=True,
- )
-
- create_table = AthenaOperator(
- task_id='create_table',
- query=QUERY_CREATE_TABLE,
- database=ATHENA_DATABASE,
- output_location=f's3://{S3_BUCKET}/{S3_KEY}',
- )
-
- # [START howto_operator_athena]
- read_table = AthenaOperator(
- task_id='read_table',
- query=QUERY_READ_TABLE,
- database=ATHENA_DATABASE,
- output_location=f's3://{S3_BUCKET}/{S3_KEY}',
- )
- # [END howto_operator_athena]
-
- # [START howto_sensor_athena]
- await_query = AthenaSensor(
- task_id='await_query',
- query_execution_id=read_table.output,
- )
- # [END howto_sensor_athena]
-
- drop_table = AthenaOperator(
- task_id='drop_table',
- query=QUERY_DROP_TABLE,
- database=ATHENA_DATABASE,
- output_location=f's3://{S3_BUCKET}/{S3_KEY}',
- )
-
- remove_s3_files = S3DeleteObjectsOperator(
- task_id='remove_s3_files',
- bucket=S3_BUCKET,
- prefix=S3_KEY,
- )
-
- (
- upload_sample_data
- >> create_table
- >> read_table
- >> await_query
- >> read_results_from_s3(read_table.output)
- >> drop_table
- >> remove_s3_files
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_batch.py b/airflow/providers/amazon/aws/example_dags/example_batch.py
deleted file mode 100644
index 959e4de52d300..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_batch.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from json import loads
-from os import environ
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.batch import BatchOperator
-from airflow.providers.amazon.aws.sensors.batch import BatchSensor
-
-# The inputs below are required for the submit batch example DAG.
-JOB_NAME = environ.get('BATCH_JOB_NAME', 'example_job_name')
-JOB_DEFINITION = environ.get('BATCH_JOB_DEFINITION', 'example_job_definition')
-JOB_QUEUE = environ.get('BATCH_JOB_QUEUE', 'example_job_queue')
-JOB_OVERRIDES = loads(environ.get('BATCH_JOB_OVERRIDES', '{}'))
-
-# An existing (externally triggered) job id is required for the sensor example DAG.
-JOB_ID = environ.get('BATCH_JOB_ID', '00000000-0000-0000-0000-000000000000')
-
-
-with DAG(
- dag_id='example_batch_submit_job',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as submit_dag:
-
- # [START howto_operator_batch]
- submit_batch_job = BatchOperator(
- task_id='submit_batch_job',
- job_name=JOB_NAME,
- job_queue=JOB_QUEUE,
- job_definition=JOB_DEFINITION,
- overrides=JOB_OVERRIDES,
- )
- # [END howto_operator_batch]
-
-with DAG(
- dag_id='example_batch_wait_for_job_sensor',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as sensor_dag:
-
- # [START howto_sensor_batch]
- wait_for_batch_job = BatchSensor(
- task_id='wait_for_batch_job',
- job_id=JOB_ID,
- )
- # [END howto_sensor_batch]
diff --git a/airflow/providers/amazon/aws/example_dags/example_cloudformation.py b/airflow/providers/amazon/aws/example_dags/example_cloudformation.py
deleted file mode 100644
index c31d3215ff5d8..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_cloudformation.py
+++ /dev/null
@@ -1,83 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from json import dumps
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.cloud_formation import (
- CloudFormationCreateStackOperator,
- CloudFormationDeleteStackOperator,
-)
-from airflow.providers.amazon.aws.sensors.cloud_formation import (
- CloudFormationCreateStackSensor,
- CloudFormationDeleteStackSensor,
-)
-
-CLOUDFORMATION_STACK_NAME = 'example-stack-name'
-# The CloudFormation template must have at least one resource to be usable, this example uses SQS
-# as a free and serverless option.
-CLOUDFORMATION_TEMPLATE = {
- 'Description': 'Stack from Airflow CloudFormation example DAG',
- 'Resources': {
- 'ExampleQueue': {
- 'Type': 'AWS::SQS::Queue',
- }
- },
-}
-CLOUDFORMATION_CREATE_PARAMETERS = {
- 'StackName': CLOUDFORMATION_STACK_NAME,
- 'TemplateBody': dumps(CLOUDFORMATION_TEMPLATE),
- 'TimeoutInMinutes': 2,
- 'OnFailure': 'DELETE', # Don't leave stacks behind if creation fails.
-}
-
-with DAG(
- dag_id='example_cloudformation',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_cloudformation_create_stack]
- create_stack = CloudFormationCreateStackOperator(
- task_id='create_stack',
- stack_name=CLOUDFORMATION_STACK_NAME,
- cloudformation_parameters=CLOUDFORMATION_CREATE_PARAMETERS,
- )
- # [END howto_operator_cloudformation_create_stack]
-
- # [START howto_sensor_cloudformation_create_stack]
- wait_for_stack_create = CloudFormationCreateStackSensor(
- task_id='wait_for_stack_creation', stack_name=CLOUDFORMATION_STACK_NAME
- )
- # [END howto_sensor_cloudformation_create_stack]
-
- # [START howto_operator_cloudformation_delete_stack]
- delete_stack = CloudFormationDeleteStackOperator(
- task_id='delete_stack', stack_name=CLOUDFORMATION_STACK_NAME
- )
- # [END howto_operator_cloudformation_delete_stack]
-
- # [START howto_sensor_cloudformation_delete_stack]
- wait_for_stack_delete = CloudFormationDeleteStackSensor(
- task_id='wait_for_stack_deletion', trigger_rule='all_done', stack_name=CLOUDFORMATION_STACK_NAME
- )
- # [END howto_sensor_cloudformation_delete_stack]
-
- chain(create_stack, wait_for_stack_create, delete_stack, wait_for_stack_delete)
diff --git a/airflow/providers/amazon/aws/example_dags/example_datasync.py b/airflow/providers/amazon/aws/example_dags/example_datasync.py
deleted file mode 100644
index 09c474079e497..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_datasync.py
+++ /dev/null
@@ -1,81 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-import re
-from datetime import datetime
-from os import getenv
-
-from airflow import models
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.datasync import DataSyncOperator
-
-TASK_ARN = getenv("TASK_ARN", "my_aws_datasync_task_arn")
-SOURCE_LOCATION_URI = getenv("SOURCE_LOCATION_URI", "smb://hostname/directory/")
-DESTINATION_LOCATION_URI = getenv("DESTINATION_LOCATION_URI", "s3://mybucket/prefix")
-CREATE_TASK_KWARGS = json.loads(getenv("CREATE_TASK_KWARGS", '{"Name": "Created by Airflow"}'))
-CREATE_SOURCE_LOCATION_KWARGS = json.loads(getenv("CREATE_SOURCE_LOCATION_KWARGS", '{}'))
-default_destination_location_kwargs = """\
-{"S3BucketArn": "arn:aws:s3:::mybucket",
- "S3Config": {"BucketAccessRoleArn":
- "arn:aws:iam::11112223344:role/r-11112223344-my-bucket-access-role"}
-}"""
-CREATE_DESTINATION_LOCATION_KWARGS = json.loads(
- getenv("CREATE_DESTINATION_LOCATION_KWARGS", re.sub(r"[\s+]", '', default_destination_location_kwargs))
-)
-UPDATE_TASK_KWARGS = json.loads(getenv("UPDATE_TASK_KWARGS", '{"Name": "Updated by Airflow"}'))
-
-with models.DAG(
-    "example_datasync",
-    schedule_interval=None,  # Override to match your needs
-    start_date=datetime(2021, 1, 1),
-    catchup=False,
-    tags=['example'],
-) as dag:
-    # [START howto_operator_datasync_specific_task]
-    # Execute a specific task
-    datasync_specific_task = DataSyncOperator(task_id="datasync_specific_task", task_arn=TASK_ARN)
-    # [END howto_operator_datasync_specific_task]
-
-    # [START howto_operator_datasync_search_task]
-    # Search and execute a task
-    datasync_search_task = DataSyncOperator(
-        task_id="datasync_search_task",
-        source_location_uri=SOURCE_LOCATION_URI,
-        destination_location_uri=DESTINATION_LOCATION_URI,
-    )
-    # [END howto_operator_datasync_search_task]
-
-    # [START howto_operator_datasync_create_task]
-    # Create a task (the task does not exist)
-    datasync_create_task = DataSyncOperator(
-        task_id="datasync_create_task",
-        source_location_uri=SOURCE_LOCATION_URI,
-        destination_location_uri=DESTINATION_LOCATION_URI,
-        create_task_kwargs=CREATE_TASK_KWARGS,
-        create_source_location_kwargs=CREATE_SOURCE_LOCATION_KWARGS,
-        create_destination_location_kwargs=CREATE_DESTINATION_LOCATION_KWARGS,
-        update_task_kwargs=UPDATE_TASK_KWARGS,
-        delete_task_after_execution=True,
-    )
-    # [END howto_operator_datasync_create_task]
-
-    chain(
-        datasync_specific_task,
-        datasync_search_task,
-        datasync_create_task,
-    )
diff --git a/airflow/providers/amazon/aws/example_dags/example_dms.py b/airflow/providers/amazon/aws/example_dags/example_dms.py
deleted file mode 100644
index caffe44353fda..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_dms.py
+++ /dev/null
@@ -1,347 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Note: DMS requires you to configure specific IAM roles/permissions. For more information, see
-https://docs.aws.amazon.com/dms/latest/userguide/CHAP_Security.html#CHAP_Security.APIRole
-"""
-
-import json
-import os
-from datetime import datetime
-
-import boto3
-from sqlalchemy import Column, MetaData, String, Table, create_engine
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.models.baseoperator import chain
-from airflow.operators.python import get_current_context
-from airflow.providers.amazon.aws.operators.dms import (
-    DmsCreateTaskOperator,
-    DmsDeleteTaskOperator,
-    DmsDescribeTasksOperator,
-    DmsStartTaskOperator,
-    DmsStopTaskOperator,
-)
-from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor
-
-S3_BUCKET = os.getenv('S3_BUCKET', 's3_bucket_name')
-ROLE_ARN = os.getenv('ROLE_ARN', 'arn:aws:iam::1234567890:role/s3_target_endpoint_role')
-
-# The project name will be used as a prefix for various entity names.
-# Use either PascalCase or camelCase. While some names require kebab-case
-# and others require snake_case, they all accept mixedCase strings.
-PROJECT_NAME = 'DmsDemo'
-
-# Config values for setting up the "Source" database.
-RDS_ENGINE = 'postgres'
-RDS_PROTOCOL = 'postgresql'
-RDS_USERNAME = 'username'
-# NEVER store your production password in plaintext in a DAG like this.
-# Use Airflow Secrets or a secret manager for this in production.
-RDS_PASSWORD = 'rds_password'
-
-# Config values for RDS.
-RDS_INSTANCE_NAME = f'{PROJECT_NAME}-instance'
-RDS_DB_NAME = f'{PROJECT_NAME}_source_database'
-
-# Config values for DMS.
-DMS_REPLICATION_INSTANCE_NAME = f'{PROJECT_NAME}-replication-instance'
-DMS_REPLICATION_TASK_ID = f'{PROJECT_NAME}-replication-task'
-SOURCE_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-source-endpoint'
-TARGET_ENDPOINT_IDENTIFIER = f'{PROJECT_NAME}-target-endpoint'
-
-# Sample data.
-TABLE_NAME = f'{PROJECT_NAME}-table'
-TABLE_HEADERS = ['apache_project', 'release_year']
-SAMPLE_DATA = [
-    ('Airflow', '2015'),
-    ('OpenOffice', '2012'),
-    ('Subversion', '2000'),
-    ('NiFi', '2006'),
-]
-TABLE_DEFINITION = {
-    'TableCount': '1',
-    'Tables': [
-        {
-            'TableName': TABLE_NAME,
-            'TableColumns': [
-                {
-                    'ColumnName': TABLE_HEADERS[0],
-                    'ColumnType': 'STRING',
-                    'ColumnNullable': 'false',
-                    'ColumnIsPk': 'true',
-                },
-                {"ColumnName": TABLE_HEADERS[1], "ColumnType": 'STRING', "ColumnLength": "4"},
-            ],
-            'TableColumnsTotal': '2',
-        }
-    ],
-}
-TABLE_MAPPINGS = {
-    'rules': [
-        {
-            'rule-type': 'selection',
-            'rule-id': '1',
-            'rule-name': '1',
-            'object-locator': {
-                'schema-name': 'public',
-                'table-name': TABLE_NAME,
-            },
-            'rule-action': 'include',
-        }
-    ]
-}
-
-
-def _create_rds_instance():
-    print('Creating RDS Instance.')
-
-    rds_client = boto3.client('rds')
-    rds_client.create_db_instance(
-        DBName=RDS_DB_NAME,
-        DBInstanceIdentifier=RDS_INSTANCE_NAME,
-        AllocatedStorage=20,
-        DBInstanceClass='db.t3.micro',
-        Engine=RDS_ENGINE,
-        MasterUsername=RDS_USERNAME,
-        MasterUserPassword=RDS_PASSWORD,
-    )
-
-    rds_client.get_waiter('db_instance_available').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
-
-    response = rds_client.describe_db_instances(DBInstanceIdentifier=RDS_INSTANCE_NAME)
-    return response['DBInstances'][0]['Endpoint']
-
-
-def _create_rds_table(rds_endpoint):
-    print('Creating table.')
-
-    hostname = rds_endpoint['Address']
-    port = rds_endpoint['Port']
-    rds_url = f'{RDS_PROTOCOL}://{RDS_USERNAME}:{RDS_PASSWORD}@{hostname}:{port}/{RDS_DB_NAME}'
-    engine = create_engine(rds_url)
-
-    table = Table(
-        TABLE_NAME,
-        MetaData(engine),
-        Column(TABLE_HEADERS[0], String, primary_key=True),
-        Column(TABLE_HEADERS[1], String),
-    )
-
-    with engine.connect() as connection:
-        # Create the Table.
-        table.create()
-        load_data = table.insert().values(SAMPLE_DATA)
-        connection.execute(load_data)
-
-        # Read the data back to verify everything is working.
-        connection.execute(table.select())
-
-
-def _create_dms_replication_instance(ti, dms_client):
-    print('Creating replication instance.')
-    instance_arn = dms_client.create_replication_instance(
-        ReplicationInstanceIdentifier=DMS_REPLICATION_INSTANCE_NAME,
-        ReplicationInstanceClass='dms.t3.micro',
-    )['ReplicationInstance']['ReplicationInstanceArn']
-
-    ti.xcom_push(key='replication_instance_arn', value=instance_arn)
-    return instance_arn
-
-
-def _create_dms_endpoints(ti, dms_client, rds_instance_endpoint):
-    print('Creating DMS source endpoint.')
-    source_endpoint_arn = dms_client.create_endpoint(
-        EndpointIdentifier=SOURCE_ENDPOINT_IDENTIFIER,
-        EndpointType='source',
-        EngineName=RDS_ENGINE,
-        Username=RDS_USERNAME,
-        Password=RDS_PASSWORD,
-        ServerName=rds_instance_endpoint['Address'],
-        Port=rds_instance_endpoint['Port'],
-        DatabaseName=RDS_DB_NAME,
-    )['Endpoint']['EndpointArn']
-
-    print('Creating DMS target endpoint.')
-    target_endpoint_arn = dms_client.create_endpoint(
-        EndpointIdentifier=TARGET_ENDPOINT_IDENTIFIER,
-        EndpointType='target',
-        EngineName='s3',
-        S3Settings={
-            'BucketName': S3_BUCKET,
-            'BucketFolder': PROJECT_NAME,
-            'ServiceAccessRoleArn': ROLE_ARN,
-            'ExternalTableDefinition': json.dumps(TABLE_DEFINITION),
-        },
-    )['Endpoint']['EndpointArn']
-
-    ti.xcom_push(key='source_endpoint_arn', value=source_endpoint_arn)
-    ti.xcom_push(key='target_endpoint_arn', value=target_endpoint_arn)
-
-
-def _await_setup_assets(dms_client, instance_arn):
-    print("Awaiting asset provisioning.")
-    dms_client.get_waiter('replication_instance_available').wait(
-        Filters=[{'Name': 'replication-instance-arn', 'Values': [instance_arn]}]
-    )
-
-
-def _delete_rds_instance():
-    print('Deleting RDS Instance.')
-
-    rds_client = boto3.client('rds')
-    rds_client.delete_db_instance(
-        DBInstanceIdentifier=RDS_INSTANCE_NAME,
-        SkipFinalSnapshot=True,
-    )
-
-    rds_client.get_waiter('db_instance_deleted').wait(DBInstanceIdentifier=RDS_INSTANCE_NAME)
-
-
-def _delete_dms_assets(dms_client):
-    ti = get_current_context()['ti']
-    replication_instance_arn = ti.xcom_pull(key='replication_instance_arn')
-    source_arn = ti.xcom_pull(key='source_endpoint_arn')
-    target_arn = ti.xcom_pull(key='target_endpoint_arn')
-
-    print('Deleting DMS assets.')
-    dms_client.delete_replication_instance(ReplicationInstanceArn=replication_instance_arn)
-    dms_client.delete_endpoint(EndpointArn=source_arn)
-    dms_client.delete_endpoint(EndpointArn=target_arn)
-
-
-def _await_all_teardowns(dms_client):
-    print('Awaiting tear-down.')
-    dms_client.get_waiter('replication_instance_deleted').wait(
-        Filters=[{'Name': 'replication-instance-id', 'Values': [DMS_REPLICATION_INSTANCE_NAME]}]
-    )
-
-    dms_client.get_waiter('endpoint_deleted').wait(
-        Filters=[
-            {
-                'Name': 'endpoint-id',
-                'Values': [SOURCE_ENDPOINT_IDENTIFIER, TARGET_ENDPOINT_IDENTIFIER],
-            }
-        ]
-    )
-
-
-@task
-def set_up():
-    ti = get_current_context()['ti']
-    dms_client = boto3.client('dms')
-
-    rds_instance_endpoint = _create_rds_instance()
-    _create_rds_table(rds_instance_endpoint)
-    instance_arn = _create_dms_replication_instance(ti, dms_client)
-    _create_dms_endpoints(ti, dms_client, rds_instance_endpoint)
-    _await_setup_assets(dms_client, instance_arn)
-
-
-@task(trigger_rule='all_done')
-def clean_up():
-    dms_client = boto3.client('dms')
-
-    _delete_rds_instance()
-    _delete_dms_assets(dms_client)
-    _await_all_teardowns(dms_client)
-
-
-with DAG(
-    dag_id='example_dms',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # [START howto_operator_dms_create_task]
-    create_task = DmsCreateTaskOperator(
-        task_id='create_task',
-        replication_task_id=DMS_REPLICATION_TASK_ID,
-        source_endpoint_arn='{{ ti.xcom_pull(key="source_endpoint_arn") }}',
-        target_endpoint_arn='{{ ti.xcom_pull(key="target_endpoint_arn") }}',
-        replication_instance_arn='{{ ti.xcom_pull(key="replication_instance_arn") }}',
-        table_mappings=TABLE_MAPPINGS,
-    )
-    # [END howto_operator_dms_create_task]
-
-    # [START howto_operator_dms_start_task]
-    start_task = DmsStartTaskOperator(
-        task_id='start_task',
-        replication_task_arn=create_task.output,
-    )
-    # [END howto_operator_dms_start_task]
-
-    # [START howto_operator_dms_describe_tasks]
-    describe_tasks = DmsDescribeTasksOperator(
-        task_id='describe_tasks',
-        describe_tasks_kwargs={
-            'Filters': [
-                {
-                    'Name': 'replication-instance-arn',
-                    'Values': ['{{ ti.xcom_pull(key="replication_instance_arn") }}'],
-                }
-            ]
-        },
-        do_xcom_push=False,
-    )
-    # [END howto_operator_dms_describe_tasks]
-
-    await_task_start = DmsTaskBaseSensor(
-        task_id='await_task_start',
-        replication_task_arn=create_task.output,
-        target_statuses=['running'],
-        termination_statuses=['stopped', 'deleting', 'failed'],
-    )
-
-    # [START howto_operator_dms_stop_task]
-    stop_task = DmsStopTaskOperator(
-        task_id='stop_task',
-        replication_task_arn=create_task.output,
-    )
-    # [END howto_operator_dms_stop_task]
-
-    # TaskCompletedSensor actually waits until task reaches the "Stopped" state, so it will work here.
-    # [START howto_sensor_dms_task_completed]
-    await_task_stop = DmsTaskCompletedSensor(
-        task_id='await_task_stop',
-        replication_task_arn=create_task.output,
-    )
-    # [END howto_sensor_dms_task_completed]
-
-    # [START howto_operator_dms_delete_task]
-    delete_task = DmsDeleteTaskOperator(
-        task_id='delete_task',
-        replication_task_arn=create_task.output,
-        trigger_rule='all_done',
-    )
-    # [END howto_operator_dms_delete_task]
-
-    chain(
-        set_up()
-        >> create_task
-        >> start_task
-        >> describe_tasks
-        >> await_task_start
-        >> stop_task
-        >> await_task_stop
-        >> delete_task
-        >> clean_up()
-    )
diff --git a/airflow/providers/amazon/aws/example_dags/example_dynamodb_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_dynamodb_to_s3.py
deleted file mode 100644
index 66334fc996522..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_dynamodb_to_s3.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import environ
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator
-
-TABLE_NAME = environ.get('DYNAMO_TABLE_NAME', 'ExistingDynamoDbTableName')
-BUCKET_NAME = environ.get('S3_BUCKET_NAME', 'ExistingS3BucketName')
-
-
-with DAG(
-    dag_id='example_dynamodb_to_s3',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-    # [START howto_transfer_dynamodb_to_s3]
-    backup_db = DynamoDBToS3Operator(
-        task_id='backup_db',
-        dynamodb_table_name=TABLE_NAME,
-        s3_bucket_name=BUCKET_NAME,
-        # Max output file size in bytes. If the Table is too large, multiple files will be created.
-        file_size=1000,
-    )
-    # [END howto_transfer_dynamodb_to_s3]
-
-    # [START howto_transfer_dynamodb_to_s3_segmented]
-    # Segmenting allows the transfer to be parallelized into {segment} number of parallel tasks.
-    backup_db_segment_1 = DynamoDBToS3Operator(
-        task_id='backup-1',
-        dynamodb_table_name=TABLE_NAME,
-        s3_bucket_name=BUCKET_NAME,
-        # Max output file size in bytes. If the Table is too large, multiple files will be created.
-        file_size=1000,
-        dynamodb_scan_kwargs={
-            "TotalSegments": 2,
-            "Segment": 0,
-        },
-    )
-
-    backup_db_segment_2 = DynamoDBToS3Operator(
-        task_id="backup-2",
-        dynamodb_table_name=TABLE_NAME,
-        s3_bucket_name=BUCKET_NAME,
-        # Max output file size in bytes. If the Table is too large, multiple files will be created.
-        file_size=1000,
-        dynamodb_scan_kwargs={
-            "TotalSegments": 2,
-            "Segment": 1,
-        },
-    )
-    # [END howto_transfer_dynamodb_to_s3_segmented]
-
-    chain(backup_db, [backup_db_segment_1, backup_db_segment_2])
diff --git a/airflow/providers/amazon/aws/example_dags/example_ec2.py b/airflow/providers/amazon/aws/example_dags/example_ec2.py
deleted file mode 100644
index 7d12aef36746e..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_ec2.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator, EC2StopInstanceOperator
-from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor
-
-INSTANCE_ID = os.getenv("INSTANCE_ID", "instance-id")
-
-with DAG(
-    dag_id='example_ec2',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-    # [START howto_operator_ec2_start_instance]
-    start_instance = EC2StartInstanceOperator(
-        task_id="ec2_start_instance",
-        instance_id=INSTANCE_ID,
-    )
-    # [END howto_operator_ec2_start_instance]
-
-    # [START howto_sensor_ec2_instance_state]
-    instance_state = EC2InstanceStateSensor(
-        task_id="ec2_instance_state",
-        instance_id=INSTANCE_ID,
-        target_state="running",
-    )
-    # [END howto_sensor_ec2_instance_state]
-
-    # [START howto_operator_ec2_stop_instance]
-    stop_instance = EC2StopInstanceOperator(
-        task_id="ec2_stop_instance",
-        instance_id=INSTANCE_ID,
-    )
-    # [END howto_operator_ec2_stop_instance]
-
-    chain(start_instance, instance_state, stop_instance)
diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs.py b/airflow/providers/amazon/aws/example_dags/example_ecs.py
deleted file mode 100644
index b8f1f67b8eb08..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_ecs.py
+++ /dev/null
@@ -1,59 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.ecs import EcsOperator
-
-with DAG(
-    dag_id='example_ecs',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # [START howto_operator_ecs]
-    hello_world = EcsOperator(
-        task_id="hello_world",
-        cluster=os.environ.get("CLUSTER_NAME", "existing_cluster_name"),
-        task_definition=os.environ.get("TASK_DEFINITION", "existing_task_definition_name"),
-        launch_type="EXTERNAL|EC2",
-        aws_conn_id="aws_ecs",
-        overrides={
-            "containerOverrides": [
-                {
-                    "name": "hello-world-container",
-                    "command": ["echo", "hello", "world"],
-                },
-            ],
-        },
-        tags={
-            "Customer": "X",
-            "Project": "Y",
-            "Application": "Z",
-            "Version": "0.0.1",
-            "Environment": "Development",
-        },
-        # [START howto_awslogs_ecs]
-        awslogs_group="/ecs/hello-world",
-        awslogs_region="aws-region",
-        awslogs_stream_prefix="ecs/hello-world-container"
-        # [END howto_awslogs_ecs]
-    )
-    # [END howto_operator_ecs]
diff --git a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py b/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
deleted file mode 100644
index 1e48367429faa..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_ecs_fargate.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.ecs import EcsOperator
-
-with DAG(
-    dag_id='example_ecs_fargate',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # [START howto_operator_ecs]
-    hello_world = EcsOperator(
-        task_id="hello_world",
-        cluster=os.environ.get("CLUSTER_NAME", "existing_cluster_name"),
-        task_definition=os.environ.get("TASK_DEFINITION", "existing_task_definition_name"),
-        launch_type="FARGATE",
-        aws_conn_id="aws_ecs",
-        overrides={
-            "containerOverrides": [
-                {
-                    "name": "hello-world-container",
-                    "command": ["echo", "hello", "world"],
-                },
-            ],
-        },
-        network_configuration={
-            "awsvpcConfiguration": {
-                "securityGroups": [os.environ.get("SECURITY_GROUP_ID", "sg-123abc")],
-                "subnets": [os.environ.get("SUBNET_ID", "subnet-123456ab")],
-            },
-        },
-        tags={
-            "Customer": "X",
-            "Project": "Y",
-            "Application": "Z",
-            "Version": "0.0.1",
-            "Environment": "Development",
-        },
-        awslogs_group="/ecs/hello-world",
-        awslogs_stream_prefix="prefix_b/hello-world-container",
-    )
-    # [END howto_operator_ecs]
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py
index 87d6d9ef44e68..da984adcf5547 100644
--- a/airflow/providers/amazon/aws/example_dags/example_eks_templated.py
+++ b/airflow/providers/amazon/aws/example_dags/example_eks_templated.py
@@ -14,9 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-# mypy ignore arg types (for templated fields)
-# type: ignore[arg-type]
+from __future__ import annotations
from datetime import datetime
@@ -31,6 +29,10 @@
)
from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor
+# mypy ignore arg types (for templated fields)
+# type: ignore[arg-type]
+
+
# Example Jinja Template format, substitute your values:
"""
{
@@ -48,10 +50,9 @@
"""
with DAG(
- dag_id='example_eks_templated',
- schedule_interval=None,
+ dag_id="example_eks_templated",
start_date=datetime(2021, 1, 1),
- tags=['example', 'templated'],
+ tags=["example", "templated"],
catchup=False,
# render_template_as_native_obj=True is what converts the Jinja to Python objects, instead of a string.
render_template_as_native_obj=True,
@@ -62,21 +63,22 @@
# Create an Amazon EKS Cluster control plane without attaching a compute service.
create_cluster = EksCreateClusterOperator(
- task_id='create_eks_cluster',
+ task_id="create_eks_cluster",
cluster_name=CLUSTER_NAME,
compute=None,
cluster_role_arn="{{ dag_run.conf['cluster_role_arn'] }}",
- resources_vpc_config="{{ dag_run.conf['resources_vpc_config'] }}",
+ # This only works with render_template_as_native_obj flag (this dag has it set)
+ resources_vpc_config="{{ dag_run.conf['resources_vpc_config'] }}", # type: ignore[arg-type]
)
await_create_cluster = EksClusterStateSensor(
- task_id='wait_for_create_cluster',
+ task_id="wait_for_create_cluster",
cluster_name=CLUSTER_NAME,
target_state=ClusterStates.ACTIVE,
)
create_nodegroup = EksCreateNodegroupOperator(
- task_id='create_eks_nodegroup',
+ task_id="create_eks_nodegroup",
cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
nodegroup_subnets="{{ dag_run.conf['nodegroup_subnets'] }}",
@@ -84,7 +86,7 @@
)
await_create_nodegroup = EksNodegroupStateSensor(
- task_id='wait_for_create_nodegroup',
+ task_id="wait_for_create_nodegroup",
cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.ACTIVE,
@@ -103,25 +105,25 @@
)
delete_nodegroup = EksDeleteNodegroupOperator(
- task_id='delete_eks_nodegroup',
+ task_id="delete_eks_nodegroup",
cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
)
await_delete_nodegroup = EksNodegroupStateSensor(
- task_id='wait_for_delete_nodegroup',
+ task_id="wait_for_delete_nodegroup",
cluster_name=CLUSTER_NAME,
nodegroup_name=NODEGROUP_NAME,
target_state=NodegroupStates.NONEXISTENT,
)
delete_cluster = EksDeleteClusterOperator(
- task_id='delete_eks_cluster',
+ task_id="delete_eks_cluster",
cluster_name=CLUSTER_NAME,
)
await_delete_cluster = EksClusterStateSensor(
- task_id='wait_for_delete_cluster',
+ task_id="wait_for_delete_cluster",
cluster_name=CLUSTER_NAME,
target_state=ClusterStates.NONEXISTENT,
)
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py
deleted file mode 100644
index e08e6525e6fc8..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_in_one_step.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-from datetime import datetime
-from os import environ
-
-from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.eks import ClusterStates, FargateProfileStates
-from airflow.providers.amazon.aws.operators.eks import (
-    EksCreateClusterOperator,
-    EksDeleteClusterOperator,
-    EksPodOperator,
-)
-from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksFargateProfileStateSensor
-
-CLUSTER_NAME = 'fargate-all-in-one'
-FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile'
-
-ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name')
-SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
-VPC_CONFIG = {
-    'subnetIds': SUBNETS,
-    'endpointPublicAccess': True,
-    'endpointPrivateAccess': False,
-}
-
-
-with DAG(
-    dag_id='example_eks_with_fargate_in_one_step',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # [START howto_operator_eks_create_cluster_with_fargate_profile]
-    # Create an Amazon EKS cluster control plane and an AWS Fargate compute platform in one step.
-    create_cluster_and_fargate_profile = EksCreateClusterOperator(
-        task_id='create_eks_cluster_and_fargate_profile',
-        cluster_name=CLUSTER_NAME,
-        cluster_role_arn=ROLE_ARN,
-        resources_vpc_config=VPC_CONFIG,
-        compute='fargate',
-        fargate_profile_name=FARGATE_PROFILE_NAME,
-        # Opting to use the same ARN for the cluster and the pod here,
-        # but a different ARN could be configured and passed if desired.
-        fargate_pod_execution_role_arn=ROLE_ARN,
-    )
-    # [END howto_operator_eks_create_cluster_with_fargate_profile]
-
-    await_create_fargate_profile = EksFargateProfileStateSensor(
-        task_id='wait_for_create_fargate_profile',
-        cluster_name=CLUSTER_NAME,
-        fargate_profile_name=FARGATE_PROFILE_NAME,
-        target_state=FargateProfileStates.ACTIVE,
-    )
-
-    start_pod = EksPodOperator(
-        task_id="run_pod",
-        pod_name="run_pod",
-        cluster_name=CLUSTER_NAME,
-        image="amazon/aws-cli:latest",
-        cmds=["sh", "-c", "echo Test Airflow; date"],
-        labels={"demo": "hello_world"},
-        get_logs=True,
-        # Delete the pod when it reaches its final state, or the execution is interrupted.
-        is_delete_operator_pod=True,
-    )
-
-    # An Amazon EKS cluster can not be deleted with attached resources such as nodegroups or Fargate profiles.
-    # Setting the `force` to `True` will delete any attached resources before deleting the cluster.
-    delete_all = EksDeleteClusterOperator(
-        task_id='delete_fargate_profile_and_cluster',
-        cluster_name=CLUSTER_NAME,
-        force_delete_compute=True,
-    )
-
-    await_delete_cluster = EksClusterStateSensor(
-        task_id='wait_for_delete_cluster',
-        cluster_name=CLUSTER_NAME,
-        target_state=ClusterStates.NONEXISTENT,
-    )
-
-    (
-        create_cluster_and_fargate_profile
-        >> await_create_fargate_profile
-        >> start_pod
-        >> delete_all
-        >> await_delete_cluster
-    )
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
deleted file mode 100644
index 3ca3b2eb87728..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_fargate_profile.py
+++ /dev/null
@@ -1,138 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-from datetime import datetime
-from os import environ
-
-from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.eks import ClusterStates, FargateProfileStates
-from airflow.providers.amazon.aws.operators.eks import (
-    EksCreateClusterOperator,
-    EksCreateFargateProfileOperator,
-    EksDeleteClusterOperator,
-    EksDeleteFargateProfileOperator,
-    EksPodOperator,
-)
-from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksFargateProfileStateSensor
-
-CLUSTER_NAME = 'fargate-demo'
-FARGATE_PROFILE_NAME = f'{CLUSTER_NAME}-profile'
-SELECTORS = [{'namespace': 'default'}]
-
-ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name')
-SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
-VPC_CONFIG = {
-    'subnetIds': SUBNETS,
-    'endpointPublicAccess': True,
-    'endpointPrivateAccess': False,
-}
-
-
-with DAG(
-    dag_id='example_eks_with_fargate_profile',
-    schedule_interval=None,
-    start_date=datetime(2021, 1, 1),
-    tags=['example'],
-    catchup=False,
-) as dag:
-
-    # Create an Amazon EKS Cluster control plane without attaching a compute service.
-    create_cluster = EksCreateClusterOperator(
-        task_id='create_eks_cluster',
-        cluster_name=CLUSTER_NAME,
-        cluster_role_arn=ROLE_ARN,
-        resources_vpc_config=VPC_CONFIG,
-        compute=None,
-    )
-
-    await_create_cluster = EksClusterStateSensor(
-        task_id='wait_for_create_cluster',
-        cluster_name=CLUSTER_NAME,
-        target_state=ClusterStates.ACTIVE,
-    )
-
-    # [START howto_operator_eks_create_fargate_profile]
-    create_fargate_profile = EksCreateFargateProfileOperator(
-        task_id='create_eks_fargate_profile',
-        cluster_name=CLUSTER_NAME,
-        pod_execution_role_arn=ROLE_ARN,
-        fargate_profile_name=FARGATE_PROFILE_NAME,
-        selectors=SELECTORS,
-    )
-    # [END howto_operator_eks_create_fargate_profile]
-
-    # [START howto_sensor_eks_fargate]
-    await_create_fargate_profile = EksFargateProfileStateSensor(
-        task_id='wait_for_create_fargate_profile',
-        cluster_name=CLUSTER_NAME,
-        fargate_profile_name=FARGATE_PROFILE_NAME,
-        target_state=FargateProfileStates.ACTIVE,
-    )
-    # [END howto_sensor_eks_fargate]
-
-    start_pod = EksPodOperator(
-        task_id="run_pod",
-        cluster_name=CLUSTER_NAME,
-        pod_name="run_pod",
-        image="amazon/aws-cli:latest",
-        cmds=["sh", "-c", "echo Test Airflow; date"],
-        labels={"demo": "hello_world"},
-        get_logs=True,
-        # Delete the pod when it reaches its final state, or the execution is interrupted.
-        is_delete_operator_pod=True,
-    )
-
-    # [START howto_operator_eks_delete_fargate_profile]
-    delete_fargate_profile = EksDeleteFargateProfileOperator(
-        task_id='delete_eks_fargate_profile',
-        cluster_name=CLUSTER_NAME,
-        fargate_profile_name=FARGATE_PROFILE_NAME,
-    )
-    # [END howto_operator_eks_delete_fargate_profile]
-
-    await_delete_fargate_profile = EksFargateProfileStateSensor(
-        task_id='wait_for_delete_fargate_profile',
-        cluster_name=CLUSTER_NAME,
-        fargate_profile_name=FARGATE_PROFILE_NAME,
-        target_state=FargateProfileStates.NONEXISTENT,
-    )
-
-    delete_cluster = EksDeleteClusterOperator(
-        task_id='delete_eks_cluster',
-        cluster_name=CLUSTER_NAME,
-    )
-
-    await_delete_cluster = EksClusterStateSensor(
-        task_id='wait_for_delete_cluster',
-        cluster_name=CLUSTER_NAME,
-        target_state=ClusterStates.NONEXISTENT,
-    )
-
-    (
-        create_cluster
-        >> await_create_cluster
-        >> create_fargate_profile
-        >> await_create_fargate_profile
-        >> start_pod
-        >> delete_fargate_profile
-        >> await_delete_fargate_profile
-        >> delete_cluster
-        >> await_delete_cluster
-    )
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
deleted file mode 100644
index 38d1bd1ad4c2f..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroup_in_one_step.py
+++ /dev/null
@@ -1,109 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-from datetime import datetime
-from os import environ
-
-from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates
-from airflow.providers.amazon.aws.operators.eks import (
- EksCreateClusterOperator,
- EksDeleteClusterOperator,
- EksPodOperator,
-)
-from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor
-
-CLUSTER_NAME = environ.get('EKS_CLUSTER_NAME', 'eks-demo')
-NODEGROUP_NAME = f'{CLUSTER_NAME}-nodegroup'
-ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name')
-SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
-VPC_CONFIG = {
- 'subnetIds': SUBNETS,
- 'endpointPublicAccess': True,
- 'endpointPrivateAccess': False,
-}
-
-
-with DAG(
- dag_id='example_eks_with_nodegroup_in_one_step',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_eks_create_cluster_with_nodegroup]
- # Create an Amazon EKS cluster control plane and an EKS nodegroup compute platform in one step.
- create_cluster_and_nodegroup = EksCreateClusterOperator(
- task_id='create_eks_cluster_and_nodegroup',
- cluster_name=CLUSTER_NAME,
- nodegroup_name=NODEGROUP_NAME,
- cluster_role_arn=ROLE_ARN,
- nodegroup_role_arn=ROLE_ARN,
- # Opting to use the same ARN for the cluster and the nodegroup here,
- # but a different ARN could be configured and passed if desired.
- resources_vpc_config=VPC_CONFIG,
- # Compute defaults to 'nodegroup' but is called out here for the purpose of the example.
- compute='nodegroup',
- )
- # [END howto_operator_eks_create_cluster_with_nodegroup]
-
- await_create_nodegroup = EksNodegroupStateSensor(
- task_id='wait_for_create_nodegroup',
- cluster_name=CLUSTER_NAME,
- nodegroup_name=NODEGROUP_NAME,
- target_state=NodegroupStates.ACTIVE,
- )
-
- start_pod = EksPodOperator(
- task_id="run_pod",
- cluster_name=CLUSTER_NAME,
- pod_name="run_pod",
- image="amazon/aws-cli:latest",
- cmds=["sh", "-c", "echo Test Airflow; date"],
- labels={"demo": "hello_world"},
- get_logs=True,
- # Delete the pod when it reaches its final state, or the execution is interrupted.
- is_delete_operator_pod=True,
- )
-
- # [START howto_operator_eks_force_delete_cluster]
- # An Amazon EKS cluster cannot be deleted with attached resources such as nodegroups or Fargate profiles.
- # Setting `force_delete_compute` to `True` will delete any attached resources before deleting the cluster.
- delete_all = EksDeleteClusterOperator(
- task_id='delete_nodegroup_and_cluster',
- cluster_name=CLUSTER_NAME,
- force_delete_compute=True,
- )
- # [END howto_operator_eks_force_delete_cluster]
-
- await_delete_cluster = EksClusterStateSensor(
- task_id='wait_for_delete_cluster',
- cluster_name=CLUSTER_NAME,
- target_state=ClusterStates.NONEXISTENT,
- )
-
- (
- create_cluster_and_nodegroup
- >> await_create_nodegroup
- >> start_pod
- >> delete_all
- >> await_delete_cluster
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py b/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
deleted file mode 100644
index efeeb14e04bfb..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_eks_with_nodegroups.py
+++ /dev/null
@@ -1,145 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# Ignore missing args provided by default_args
-# type: ignore[call-arg]
-
-from datetime import datetime
-from os import environ
-
-from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.hooks.eks import ClusterStates, NodegroupStates
-from airflow.providers.amazon.aws.operators.eks import (
- EksCreateClusterOperator,
- EksCreateNodegroupOperator,
- EksDeleteClusterOperator,
- EksDeleteNodegroupOperator,
- EksPodOperator,
-)
-from airflow.providers.amazon.aws.sensors.eks import EksClusterStateSensor, EksNodegroupStateSensor
-
-CLUSTER_NAME = 'eks-demo'
-NODEGROUP_SUFFIX = '-nodegroup'
-NODEGROUP_NAME = CLUSTER_NAME + NODEGROUP_SUFFIX
-ROLE_ARN = environ.get('EKS_DEMO_ROLE_ARN', 'arn:aws:iam::123456789012:role/role_name')
-SUBNETS = environ.get('EKS_DEMO_SUBNETS', 'subnet-12345ab subnet-67890cd').split(' ')
-VPC_CONFIG = {
- 'subnetIds': SUBNETS,
- 'endpointPublicAccess': True,
- 'endpointPrivateAccess': False,
-}
-
-
-with DAG(
- dag_id='example_eks_with_nodegroups',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_eks_create_cluster]
- # Create an Amazon EKS Cluster control plane without attaching a compute service.
- create_cluster = EksCreateClusterOperator(
- task_id='create_eks_cluster',
- cluster_name=CLUSTER_NAME,
- cluster_role_arn=ROLE_ARN,
- resources_vpc_config=VPC_CONFIG,
- compute=None,
- )
- # [END howto_operator_eks_create_cluster]
-
- # [START howto_sensor_eks_cluster]
- await_create_cluster = EksClusterStateSensor(
- task_id='wait_for_create_cluster',
- cluster_name=CLUSTER_NAME,
- target_state=ClusterStates.ACTIVE,
- )
- # [END howto_sensor_eks_cluster]
-
- # [START howto_operator_eks_create_nodegroup]
- create_nodegroup = EksCreateNodegroupOperator(
- task_id='create_eks_nodegroup',
- cluster_name=CLUSTER_NAME,
- nodegroup_name=NODEGROUP_NAME,
- nodegroup_subnets=SUBNETS,
- nodegroup_role_arn=ROLE_ARN,
- )
- # [END howto_operator_eks_create_nodegroup]
-
- # [START howto_sensor_eks_nodegroup]
- await_create_nodegroup = EksNodegroupStateSensor(
- task_id='wait_for_create_nodegroup',
- cluster_name=CLUSTER_NAME,
- nodegroup_name=NODEGROUP_NAME,
- target_state=NodegroupStates.ACTIVE,
- )
- # [END howto_sensor_eks_nodegroup]
-
- # [START howto_operator_eks_pod_operator]
- start_pod = EksPodOperator(
- task_id="run_pod",
- cluster_name=CLUSTER_NAME,
- pod_name="run_pod",
- image="amazon/aws-cli:latest",
- cmds=["sh", "-c", "ls"],
- labels={"demo": "hello_world"},
- get_logs=True,
- # Delete the pod when it reaches its final state, or the execution is interrupted.
- is_delete_operator_pod=True,
- )
- # [END howto_operator_eks_pod_operator]
-
- # [START howto_operator_eks_delete_nodegroup]
- delete_nodegroup = EksDeleteNodegroupOperator(
- task_id='delete_eks_nodegroup',
- cluster_name=CLUSTER_NAME,
- nodegroup_name=NODEGROUP_NAME,
- )
- # [END howto_operator_eks_delete_nodegroup]
-
- await_delete_nodegroup = EksNodegroupStateSensor(
- task_id='wait_for_delete_nodegroup',
- cluster_name=CLUSTER_NAME,
- nodegroup_name=NODEGROUP_NAME,
- target_state=NodegroupStates.NONEXISTENT,
- )
-
- # [START howto_operator_eks_delete_cluster]
- delete_cluster = EksDeleteClusterOperator(
- task_id='delete_eks_cluster',
- cluster_name=CLUSTER_NAME,
- )
- # [END howto_operator_eks_delete_cluster]
-
- await_delete_cluster = EksClusterStateSensor(
- task_id='wait_for_delete_cluster',
- cluster_name=CLUSTER_NAME,
- target_state=ClusterStates.NONEXISTENT,
- )
-
- (
- create_cluster
- >> await_create_cluster
- >> create_nodegroup
- >> await_create_nodegroup
- >> start_pod
- >> delete_nodegroup
- >> await_delete_nodegroup
- >> delete_cluster
- >> await_delete_cluster
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_eks_job.py b/airflow/providers/amazon/aws/example_dags/example_emr_eks_job.py
deleted file mode 100644
index e8932630d9d45..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_emr_eks_job.py
+++ /dev/null
@@ -1,75 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
-
-VIRTUAL_CLUSTER_ID = os.getenv("VIRTUAL_CLUSTER_ID", "test-cluster")
-JOB_ROLE_ARN = os.getenv("JOB_ROLE_ARN", "arn:aws:iam::012345678912:role/emr_eks_default_role")
-
-# [START howto_operator_emr_eks_config]
-JOB_DRIVER_ARG = {
- "sparkSubmitJobDriver": {
- "entryPoint": "local:///usr/lib/spark/examples/src/main/python/pi.py",
- "sparkSubmitParameters": "--conf spark.executor.instances=2 --conf spark.executor.memory=2G --conf spark.executor.cores=2 --conf spark.driver.cores=1", # noqa: E501
- }
-}
-
-CONFIGURATION_OVERRIDES_ARG = {
- "applicationConfiguration": [
- {
- "classification": "spark-defaults",
- "properties": {
- "spark.hadoop.hive.metastore.client.factory.class": "com.amazonaws.glue.catalog.metastore.AWSGlueDataCatalogHiveClientFactory", # noqa: E501
- },
- }
- ],
- "monitoringConfiguration": {
- "cloudWatchMonitoringConfiguration": {
- "logGroupName": "/aws/emr-eks-spark",
- "logStreamNamePrefix": "airflow",
- }
- },
-}
-# [END howto_operator_emr_eks_config]
-
-with DAG(
- dag_id='example_emr_eks_job',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # An example of how to get the cluster id and arn from an Airflow connection
- # VIRTUAL_CLUSTER_ID = '{{ conn.emr_eks.extra_dejson["virtual_cluster_id"] }}'
- # JOB_ROLE_ARN = '{{ conn.emr_eks.extra_dejson["job_role_arn"] }}'
-
- # [START howto_operator_emr_eks_job]
- job_starter = EmrContainerOperator(
- task_id="start_job",
- virtual_cluster_id=VIRTUAL_CLUSTER_ID,
- execution_role_arn=JOB_ROLE_ARN,
- release_label="emr-6.3.0-latest",
- job_driver=JOB_DRIVER_ARG,
- configuration_overrides=CONFIGURATION_OVERRIDES_ARG,
- name="pi.py",
- )
- # [END howto_operator_emr_eks_job]
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py
deleted file mode 100644
index ab92a3cb21a38..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_automatic_steps.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.emr import EmrCreateJobFlowOperator
-from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor
-
-JOB_FLOW_ROLE = os.getenv('EMR_JOB_FLOW_ROLE', 'EMR_EC2_DefaultRole')
-SERVICE_ROLE = os.getenv('EMR_SERVICE_ROLE', 'EMR_DefaultRole')
-
-# [START howto_operator_emr_automatic_steps_config]
-SPARK_STEPS = [
- {
- 'Name': 'calculate_pi',
- 'ActionOnFailure': 'CONTINUE',
- 'HadoopJarStep': {
- 'Jar': 'command-runner.jar',
- 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'],
- },
- }
-]
-
-JOB_FLOW_OVERRIDES = {
- 'Name': 'PiCalc',
- 'ReleaseLabel': 'emr-5.29.0',
- 'Applications': [{'Name': 'Spark'}],
- 'Instances': {
- 'InstanceGroups': [
- {
- 'Name': 'Primary node',
- 'Market': 'ON_DEMAND',
- 'InstanceRole': 'MASTER',
- 'InstanceType': 'm5.xlarge',
- 'InstanceCount': 1,
- },
- ],
- 'KeepJobFlowAliveWhenNoSteps': False,
- 'TerminationProtected': False,
- },
- 'Steps': SPARK_STEPS,
- 'JobFlowRole': JOB_FLOW_ROLE,
- 'ServiceRole': SERVICE_ROLE,
-}
-# [END howto_operator_emr_automatic_steps_config]
-
-
-with DAG(
- dag_id='example_emr_job_flow_automatic_steps',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_emr_create_job_flow]
- job_flow_creator = EmrCreateJobFlowOperator(
- task_id='create_job_flow',
- job_flow_overrides=JOB_FLOW_OVERRIDES,
- )
- # [END howto_operator_emr_create_job_flow]
-
- # [START howto_sensor_emr_job_flow_sensor]
- job_sensor = EmrJobFlowSensor(
- task_id='check_job_flow',
- job_flow_id=job_flow_creator.output,
- )
- # [END howto_sensor_emr_job_flow_sensor]
diff --git a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py b/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py
deleted file mode 100644
index d18237ecb0723..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_emr_job_flow_manual_steps.py
+++ /dev/null
@@ -1,106 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.emr import (
- EmrAddStepsOperator,
- EmrCreateJobFlowOperator,
- EmrTerminateJobFlowOperator,
-)
-from airflow.providers.amazon.aws.sensors.emr import EmrStepSensor
-
-JOB_FLOW_ROLE = os.getenv('EMR_JOB_FLOW_ROLE', 'EMR_EC2_DefaultRole')
-SERVICE_ROLE = os.getenv('EMR_SERVICE_ROLE', 'EMR_DefaultRole')
-
-SPARK_STEPS = [
- {
- 'Name': 'calculate_pi',
- 'ActionOnFailure': 'CONTINUE',
- 'HadoopJarStep': {
- 'Jar': 'command-runner.jar',
- 'Args': ['/usr/lib/spark/bin/run-example', 'SparkPi', '10'],
- },
- }
-]
-
-JOB_FLOW_OVERRIDES = {
- 'Name': 'PiCalc',
- 'ReleaseLabel': 'emr-5.29.0',
- 'Applications': [{'Name': 'Spark'}],
- 'Instances': {
- 'InstanceGroups': [
- {
- 'Name': 'Primary node',
- 'Market': 'ON_DEMAND',
- 'InstanceRole': 'MASTER',
- 'InstanceType': 'm5.xlarge',
- 'InstanceCount': 1,
- },
- ],
- 'KeepJobFlowAliveWhenNoSteps': False,
- 'TerminationProtected': False,
- },
- 'JobFlowRole': JOB_FLOW_ROLE,
- 'ServiceRole': SERVICE_ROLE,
-}
-
-
-with DAG(
- dag_id='example_emr_job_flow_manual_steps',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- cluster_creator = EmrCreateJobFlowOperator(
- task_id='create_job_flow',
- job_flow_overrides=JOB_FLOW_OVERRIDES,
- )
-
- # [START howto_operator_emr_add_steps]
- step_adder = EmrAddStepsOperator(
- task_id='add_steps',
- job_flow_id=cluster_creator.output,
- steps=SPARK_STEPS,
- )
- # [END howto_operator_emr_add_steps]
-
- # [START howto_sensor_emr_step_sensor]
- step_checker = EmrStepSensor(
- task_id='watch_step',
- job_flow_id=cluster_creator.output,
- step_id="{{ task_instance.xcom_pull(task_ids='add_steps', key='return_value')[0] }}",
- )
- # [END howto_sensor_emr_step_sensor]
-
- # [START howto_operator_emr_terminate_job_flow]
- cluster_remover = EmrTerminateJobFlowOperator(
- task_id='remove_cluster',
- job_flow_id=cluster_creator.output,
- )
- # [END howto_operator_emr_terminate_job_flow]
-
- chain(
- step_adder,
- step_checker,
- cluster_remover,
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py
index d01ca38810729..4c187dc1e0c38 100644
--- a/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_ftp_to_s3.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from __future__ import annotations
import os
from datetime import datetime
@@ -27,7 +27,6 @@
with models.DAG(
"example_ftp_to_s3",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag:
diff --git a/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py
index d9d04c73ffa31..ad5c8072df0a8 100644
--- a/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_gcs_to_s3.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import os
from datetime import datetime
@@ -26,7 +27,6 @@
with DAG(
dag_id="example_gcs_to_s3",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
tags=["example"],
catchup=False,
diff --git a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py
index 2df40a7d0c37b..9c9f8a1976e89 100644
--- a/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py
+++ b/airflow/providers/amazon/aws/example_dags/example_glacier_to_gcs.py
@@ -14,11 +14,16 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import os
from datetime import datetime
from airflow import DAG
-from airflow.providers.amazon.aws.operators.glacier import GlacierCreateJobOperator
+from airflow.providers.amazon.aws.operators.glacier import (
+ GlacierCreateJobOperator,
+ GlacierUploadArchiveOperator,
+)
from airflow.providers.amazon.aws.sensors.glacier import GlacierJobOperationSensor
from airflow.providers.amazon.aws.transfers.glacier_to_gcs import GlacierToGCSOperator
@@ -28,7 +33,6 @@
with DAG(
"example_glacier_to_gcs",
- schedule_interval=None,
start_date=datetime(2021, 1, 1), # Override to match your needs
catchup=False,
) as dag:
@@ -45,6 +49,12 @@
)
# [END howto_sensor_glacier_job_operation]
+ # [START howto_operator_glacier_upload_archive]
+ upload_archive_to_glacier = GlacierUploadArchiveOperator(
+ vault_name=VAULT_NAME, body=b"Test Data", task_id="upload_data_to_glacier"
+ )
+ # [END howto_operator_glacier_upload_archive]
+
# [START howto_transfer_glacier_to_gcs]
transfer_archive_to_gcs = GlacierToGCSOperator(
task_id="transfer_archive_to_gcs",
@@ -59,4 +69,4 @@
)
# [END howto_transfer_glacier_to_gcs]
- create_glacier_job >> wait_for_operation_complete >> transfer_archive_to_gcs
+ create_glacier_job >> wait_for_operation_complete >> upload_archive_to_glacier >> transfer_archive_to_gcs
diff --git a/airflow/providers/amazon/aws/example_dags/example_glue.py b/airflow/providers/amazon/aws/example_dags/example_glue.py
deleted file mode 100644
index 65417dc24f97c..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_glue.py
+++ /dev/null
@@ -1,123 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.amazon.aws.operators.glue import GlueJobOperator
-from airflow.providers.amazon.aws.operators.glue_crawler import GlueCrawlerOperator
-from airflow.providers.amazon.aws.sensors.glue import GlueJobSensor
-from airflow.providers.amazon.aws.sensors.glue_crawler import GlueCrawlerSensor
-
-GLUE_DATABASE_NAME = getenv('GLUE_DATABASE_NAME', 'glue_database_name')
-GLUE_EXAMPLE_S3_BUCKET = getenv('GLUE_EXAMPLE_S3_BUCKET', 'glue_example_s3_bucket')
-
-# Role needs putobject/getobject access to the above bucket as well as the glue
-# service role, see docs here: https://docs.aws.amazon.com/glue/latest/dg/create-an-iam-role.html
-GLUE_CRAWLER_ROLE = getenv('GLUE_CRAWLER_ROLE', 'glue_crawler_role')
-GLUE_CRAWLER_NAME = 'example_crawler'
-GLUE_CRAWLER_CONFIG = {
- 'Name': GLUE_CRAWLER_NAME,
- 'Role': GLUE_CRAWLER_ROLE,
- 'DatabaseName': GLUE_DATABASE_NAME,
- 'Targets': {
- 'S3Targets': [
- {
- 'Path': f'{GLUE_EXAMPLE_S3_BUCKET}/input',
- }
- ]
- },
-}
-
-# Example csv data used as input to the example AWS Glue Job.
-EXAMPLE_CSV = '''
-apple,0.5
-milk,2.5
-bread,4.0
-'''
-
-# Example Spark script to operate on the above sample csv data.
-EXAMPLE_SCRIPT = f'''
-from pyspark.context import SparkContext
-from awsglue.context import GlueContext
-
-glueContext = GlueContext(SparkContext.getOrCreate())
-datasource = glueContext.create_dynamic_frame.from_catalog(
- database='{GLUE_DATABASE_NAME}', table_name='input')
-print('There are %s items in the table' % datasource.count())
-
-datasource.toDF().write.format('csv').mode("append").save('s3://{GLUE_EXAMPLE_S3_BUCKET}/output')
-'''
-
-
-@task(task_id='setup__upload_artifacts_to_s3')
-def upload_artifacts_to_s3():
- '''Upload example CSV input data and an example Spark script to be used by the Glue Job'''
- s3_hook = S3Hook()
- s3_load_kwargs = {"replace": True, "bucket_name": GLUE_EXAMPLE_S3_BUCKET}
- s3_hook.load_string(string_data=EXAMPLE_CSV, key='input/input.csv', **s3_load_kwargs)
- s3_hook.load_string(string_data=EXAMPLE_SCRIPT, key='etl_script.py', **s3_load_kwargs)
-
-
-with DAG(
- dag_id='example_glue',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as glue_dag:
-
- setup_upload_artifacts_to_s3 = upload_artifacts_to_s3()
-
- # [START howto_operator_glue_crawler]
- crawl_s3 = GlueCrawlerOperator(
- task_id='crawl_s3',
- config=GLUE_CRAWLER_CONFIG,
- wait_for_completion=False,
- )
- # [END howto_operator_glue_crawler]
-
- # [START howto_sensor_glue_crawler]
- wait_for_crawl = GlueCrawlerSensor(task_id='wait_for_crawl', crawler_name=GLUE_CRAWLER_NAME)
- # [END howto_sensor_glue_crawler]
-
- # [START howto_operator_glue]
- job_name = 'example_glue_job'
- submit_glue_job = GlueJobOperator(
- task_id='submit_glue_job',
- job_name=job_name,
- wait_for_completion=False,
- script_location=f's3://{GLUE_EXAMPLE_S3_BUCKET}/etl_script.py',
- s3_bucket=GLUE_EXAMPLE_S3_BUCKET,
- iam_role_name=GLUE_CRAWLER_ROLE.split('/')[-1],
- create_job_kwargs={'GlueVersion': '3.0', 'NumberOfWorkers': 2, 'WorkerType': 'G.1X'},
- )
- # [END howto_operator_glue]
-
- # [START howto_sensor_glue]
- wait_for_job = GlueJobSensor(
- task_id='wait_for_job',
- job_name=job_name,
- # Job ID extracted from previous Glue Job Operator task
- run_id=submit_glue_job.output,
- )
- # [END howto_sensor_glue]
-
- chain(setup_upload_artifacts_to_s3, crawl_s3, wait_for_crawl, submit_glue_job, wait_for_job)
diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py
index 7b6e4b291a66c..c405e802ac259 100644
--- a/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_google_api_sheets_to_s3.py
@@ -18,6 +18,7 @@
This is a basic example dag for using `GoogleApiToS3Operator` to retrieve Google Sheets data
You need to set all env variables to request the data.
"""
+from __future__ import annotations
from datetime import datetime
from os import getenv
@@ -31,18 +32,17 @@
with DAG(
dag_id="example_google_api_sheets_to_s3",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
# [START howto_transfer_google_api_sheets_to_s3]
task_google_sheets_values_to_s3 = GoogleApiToS3Operator(
- task_id='google_sheet_data_to_s3',
- google_api_service_name='sheets',
- google_api_service_version='v4',
- google_api_endpoint_path='sheets.spreadsheets.values.get',
- google_api_endpoint_params={'spreadsheetId': GOOGLE_SHEET_ID, 'range': GOOGLE_SHEET_RANGE},
+ task_id="google_sheet_data_to_s3",
+ google_api_service_name="sheets",
+ google_api_service_version="v4",
+ google_api_endpoint_path="sheets.spreadsheets.values.get",
+ google_api_endpoint_params={"spreadsheetId": GOOGLE_SHEET_ID, "range": GOOGLE_SHEET_RANGE},
s3_destination_key=S3_DESTINATION_KEY,
)
# [END howto_transfer_google_api_sheets_to_s3]
diff --git a/airflow/providers/amazon/aws/example_dags/example_google_api_youtube_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_google_api_youtube_to_s3.py
deleted file mode 100644
index 682cc83919129..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_google_api_youtube_to_s3.py
+++ /dev/null
@@ -1,113 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This is a more advanced example dag for using `GoogleApiToS3Operator` which uses xcom to pass data between
-tasks to retrieve specific information about YouTube videos:
-
-First it searches for up to 50 videos (due to pagination) in a given time range
-(YOUTUBE_VIDEO_PUBLISHED_AFTER, YOUTUBE_VIDEO_PUBLISHED_BEFORE) on a YouTube channel (YOUTUBE_CHANNEL_ID),
-saves the response in S3, and passes the YouTube IDs to the next request, which then gets the information
-(YOUTUBE_VIDEO_FIELDS) for the requested videos and saves it in S3 (S3_DESTINATION_KEY).
-
-Further information:
-
-YOUTUBE_VIDEO_PUBLISHED_AFTER and YOUTUBE_VIDEO_PUBLISHED_BEFORE need to be formatted as
-"YYYY-MM-DDThh:mm:ss.sZ". See https://developers.google.com/youtube/v3/docs/search/list for more information.
-YOUTUBE_VIDEO_PARTS depends on the fields you pass via YOUTUBE_VIDEO_FIELDS. See
-https://developers.google.com/youtube/v3/docs/videos/list#parameters for more information.
-YOUTUBE_CONN_ID is optional for public videos. Authentication is only needed when the YouTube channel
-you want to retrieve videos from contains private videos.
-"""
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.transfers.google_api_to_s3 import GoogleApiToS3Operator
-
-YOUTUBE_CHANNEL_ID = getenv(
- "YOUTUBE_CHANNEL_ID", "UCSXwxpWZQ7XZ1WL3wqevChA"
-) # Youtube channel "Apache Airflow"
-YOUTUBE_VIDEO_PUBLISHED_AFTER = getenv("YOUTUBE_VIDEO_PUBLISHED_AFTER", "2019-09-25T00:00:00Z")
-YOUTUBE_VIDEO_PUBLISHED_BEFORE = getenv("YOUTUBE_VIDEO_PUBLISHED_BEFORE", "2019-10-18T00:00:00Z")
-S3_BUCKET_NAME = getenv("S3_DESTINATION_KEY", "s3://bucket-test")
-YOUTUBE_VIDEO_PARTS = getenv("YOUTUBE_VIDEO_PARTS", "snippet")
-YOUTUBE_VIDEO_FIELDS = getenv("YOUTUBE_VIDEO_FIELDS", "items(id,snippet(description,publishedAt,tags,title))")
-
-
-@task(task_id='transform_video_ids')
-def transform_video_ids(**kwargs):
- task_instance = kwargs['task_instance']
- output = task_instance.xcom_pull(task_ids="video_ids_to_s3", key="video_ids_response")
- video_ids = [item['id']['videoId'] for item in output['items']]
-
- if not video_ids:
- video_ids = []
-
- kwargs['task_instance'].xcom_push(key='video_ids', value={'id': ','.join(video_ids)})
-
-
-with DAG(
- dag_id="example_google_api_youtube_to_s3",
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=['example'],
-) as dag:
- # [START howto_transfer_google_api_youtube_search_to_s3]
- task_video_ids_to_s3 = GoogleApiToS3Operator(
- task_id='video_ids_to_s3',
- google_api_service_name='youtube',
- google_api_service_version='v3',
- google_api_endpoint_path='youtube.search.list',
- google_api_endpoint_params={
- 'part': 'snippet',
- 'channelId': YOUTUBE_CHANNEL_ID,
- 'maxResults': 50,
- 'publishedAfter': YOUTUBE_VIDEO_PUBLISHED_AFTER,
- 'publishedBefore': YOUTUBE_VIDEO_PUBLISHED_BEFORE,
- 'type': 'video',
- 'fields': 'items/id/videoId',
- },
- google_api_response_via_xcom='video_ids_response',
- s3_destination_key=f'{S3_BUCKET_NAME}/youtube_search.json',
- s3_overwrite=True,
- )
- # [END howto_transfer_google_api_youtube_search_to_s3]
-
- task_transform_video_ids = transform_video_ids()
-
- # [START howto_transfer_google_api_youtube_list_to_s3]
- task_video_data_to_s3 = GoogleApiToS3Operator(
- task_id='video_data_to_s3',
- google_api_service_name='youtube',
- google_api_service_version='v3',
- google_api_endpoint_path='youtube.videos.list',
- google_api_endpoint_params={
- 'part': YOUTUBE_VIDEO_PARTS,
- 'maxResults': 50,
- 'fields': YOUTUBE_VIDEO_FIELDS,
- },
- google_api_endpoint_params_via_xcom='video_ids',
- s3_destination_key=f'{S3_BUCKET_NAME}/youtube_videos.json',
- s3_overwrite=True,
- )
- # [END howto_transfer_google_api_youtube_list_to_s3]
-
- chain(task_video_ids_to_s3, task_transform_video_ids, task_video_data_to_s3)
diff --git a/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py b/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py
index 6fccd64c6d596..1e4375febd607 100644
--- a/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py
+++ b/airflow/providers/amazon/aws/example_dags/example_hive_to_dynamodb.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This DAG will not work unless you create an Amazon EMR cluster running
Apache Hive and copy data into it following steps 1-4 (inclusive) here:
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/EMRforDynamoDB.Tutorial.html
"""
+from __future__ import annotations
import os
from datetime import datetime
@@ -31,33 +31,33 @@
from airflow.providers.amazon.aws.transfers.hive_to_dynamodb import HiveToDynamoDBOperator
from airflow.utils import db
-DYNAMODB_TABLE_NAME = 'example_hive_to_dynamodb_table'
-HIVE_CONNECTION_ID = os.getenv('HIVE_CONNECTION_ID', 'hive_on_emr')
-HIVE_HOSTNAME = os.getenv('HIVE_HOSTNAME', 'ec2-123-45-67-890.compute-1.amazonaws.com')
+DYNAMODB_TABLE_NAME = "example_hive_to_dynamodb_table"
+HIVE_CONNECTION_ID = os.getenv("HIVE_CONNECTION_ID", "hive_on_emr")
+HIVE_HOSTNAME = os.getenv("HIVE_HOSTNAME", "ec2-123-45-67-890.compute-1.amazonaws.com")
# These values assume you set up the Hive data source following the link above.
-DYNAMODB_TABLE_HASH_KEY = 'feature_id'
-HIVE_SQL = 'SELECT feature_id, feature_name, feature_class, state_alpha FROM hive_features'
+DYNAMODB_TABLE_HASH_KEY = "feature_id"
+HIVE_SQL = "SELECT feature_id, feature_name, feature_class, state_alpha FROM hive_features"
@task
def create_dynamodb_table():
- client = DynamoDBHook(client_type='dynamodb').conn
+ client = DynamoDBHook(client_type="dynamodb").conn
client.create_table(
TableName=DYNAMODB_TABLE_NAME,
KeySchema=[
- {'AttributeName': DYNAMODB_TABLE_HASH_KEY, 'KeyType': 'HASH'},
+ {"AttributeName": DYNAMODB_TABLE_HASH_KEY, "KeyType": "HASH"},
],
AttributeDefinitions=[
- {'AttributeName': DYNAMODB_TABLE_HASH_KEY, 'AttributeType': 'N'},
+ {"AttributeName": DYNAMODB_TABLE_HASH_KEY, "AttributeType": "N"},
],
- ProvisionedThroughput={'ReadCapacityUnits': 20, 'WriteCapacityUnits': 20},
+ ProvisionedThroughput={"ReadCapacityUnits": 20, "WriteCapacityUnits": 20},
)
# DynamoDB table creation is nearly, but not quite, instantaneous.
# Wait for the table to be active to avoid race conditions writing to it.
- waiter = client.get_waiter('table_exists')
- waiter.wait(TableName=DYNAMODB_TABLE_NAME, WaiterConfig={'Delay': 1})
+ waiter = client.get_waiter("table_exists")
+ waiter.wait(TableName=DYNAMODB_TABLE_NAME, WaiterConfig={"Delay": 1})
@task
@@ -66,24 +66,24 @@ def get_dynamodb_item_count():
A DynamoDB table has an ItemCount value, but it is only updated every six hours.
To verify this DAG worked, we will scan the table and count the items manually.
"""
- table = DynamoDBHook(resource_type='dynamodb').conn.Table(DYNAMODB_TABLE_NAME)
+ table = DynamoDBHook(resource_type="dynamodb").conn.Table(DYNAMODB_TABLE_NAME)
- response = table.scan(Select='COUNT')
- item_count = response['Count']
+ response = table.scan(Select="COUNT")
+ item_count = response["Count"]
- while 'LastEvaluatedKey' in response:
- response = table.scan(Select='COUNT', ExclusiveStartKey=response['LastEvaluatedKey'])
- item_count += response['Count']
+ while "LastEvaluatedKey" in response:
+ response = table.scan(Select="COUNT", ExclusiveStartKey=response["LastEvaluatedKey"])
+ item_count += response["Count"]
- print(f'DynamoDB table contains {item_count} items.')
+ print(f"DynamoDB table contains {item_count} items.")
# Included for sample purposes only; in production you wouldn't delete
# the table you just backed your data up to. Using 'all_done' so even
# if an intermediate step fails, the DAG will clean up after itself.
-@task(trigger_rule='all_done')
+@task(trigger_rule="all_done")
def delete_dynamodb_table():
- DynamoDBHook(client_type='dynamodb').conn.delete_table(TableName=DYNAMODB_TABLE_NAME)
+ DynamoDBHook(client_type="dynamodb").conn.delete_table(TableName=DYNAMODB_TABLE_NAME)
# Included for sample purposes only; in production this should
@@ -96,7 +96,7 @@ def configure_hive_connection():
db.merge_conn(
Connection(
conn_id=HIVE_CONNECTION_ID,
- conn_type='hiveserver2',
+ conn_type="hiveserver2",
host=HIVE_HOSTNAME,
port=10000,
)
@@ -104,10 +104,9 @@ def configure_hive_connection():
with DAG(
- dag_id='example_hive_to_dynamodb',
- schedule_interval=None,
+ dag_id="example_hive_to_dynamodb",
start_date=datetime(2021, 1, 1),
- tags=['example'],
+ tags=["example"],
catchup=False,
) as dag:
# Add the prerequisites docstring to the DAG in the UI.
@@ -115,7 +114,7 @@ def configure_hive_connection():
# [START howto_transfer_hive_to_dynamodb]
backup_to_dynamodb = HiveToDynamoDBOperator(
- task_id='backup_to_dynamodb',
+ task_id="backup_to_dynamodb",
hiveserver2_conn_id=HIVE_CONNECTION_ID,
sql=HIVE_SQL,
table_name=DYNAMODB_TABLE_NAME,
diff --git a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py
index 357d92a6f5694..3d63f115f0900 100644
--- a/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_imap_attachment_to_s3.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This is an example dag for using `ImapAttachmentToS3Operator` to transfer an email attachment via IMAP
protocol from a mail server to S3 Bucket.
"""
+from __future__ import annotations
from datetime import datetime
from os import getenv
@@ -35,13 +35,12 @@
with DAG(
dag_id="example_imap_attachment_to_s3",
start_date=datetime(2021, 1, 1),
- schedule_interval=None,
catchup=False,
- tags=['example'],
+ tags=["example"],
) as dag:
# [START howto_transfer_imap_attachment_to_s3]
task_transfer_imap_attachment_to_s3 = ImapAttachmentToS3Operator(
- task_id='transfer_imap_attachment_to_s3',
+ task_id="transfer_imap_attachment_to_s3",
imap_attachment_name=IMAP_ATTACHMENT_NAME,
s3_bucket=S3_BUCKET,
s3_key=S3_KEY,
diff --git a/airflow/providers/amazon/aws/example_dags/example_lambda.py b/airflow/providers/amazon/aws/example_dags/example_lambda.py
deleted file mode 100644
index 3b87c3aa3c3fb..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_lambda.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import json
-from datetime import datetime, timedelta
-from os import getenv
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.aws_lambda import AwsLambdaInvokeFunctionOperator
-
-# [START howto_operator_lambda_env_variables]
-LAMBDA_FUNCTION_NAME = getenv("LAMBDA_FUNCTION_NAME", "test-function")
-# [END howto_operator_lambda_env_variables]
-
-SAMPLE_EVENT = json.dumps({"SampleEvent": {"SampleData": {"Name": "XYZ", "DoB": "1993-01-01"}}})
-
-with DAG(
- dag_id='example_lambda',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- dagrun_timeout=timedelta(minutes=60),
- tags=['example'],
- catchup=False,
-) as dag:
- # [START howto_operator_lambda]
- invoke_lambda_function = AwsLambdaInvokeFunctionOperator(
- task_id='setup__invoke_lambda_function',
- function_name=LAMBDA_FUNCTION_NAME,
- payload=SAMPLE_EVENT,
- )
- # [END howto_operator_lambda]
diff --git a/airflow/providers/amazon/aws/example_dags/example_local_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_local_to_s3.py
deleted file mode 100644
index 05f9c74d7685b..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_local_to_s3.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import os
-
-from airflow import models
-from airflow.providers.amazon.aws.transfers.local_to_s3 import LocalFilesystemToS3Operator
-from airflow.utils.dates import datetime
-
-S3_BUCKET = os.environ.get("S3_BUCKET", "test-bucket")
-S3_KEY = os.environ.get("S3_KEY", "key")
-
-with models.DAG(
- "example_local_to_s3",
- schedule_interval=None,
- start_date=datetime(2021, 1, 1), # Override to match your needs
- catchup=False,
-) as dag:
- # [START howto_transfer_local_to_s3]
- create_local_to_s3_job = LocalFilesystemToS3Operator(
- task_id="create_local_to_s3_job",
- filename="relative/path/to/file.csv",
- dest_key=S3_KEY,
- dest_bucket=S3_BUCKET,
- replace=True,
- )
- # [END howto_transfer_local_to_s3]
diff --git a/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py
index e95964b59cee1..39e4d6ac89869 100644
--- a/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_mongo_to_s3.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from __future__ import annotations
import os
@@ -29,7 +29,6 @@
with models.DAG(
"example_mongo_to_s3",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag:
diff --git a/airflow/providers/amazon/aws/example_dags/example_quicksight.py b/airflow/providers/amazon/aws/example_dags/example_quicksight.py
deleted file mode 100644
index 5c50a54492617..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_quicksight.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import os
-from datetime import datetime
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.quicksight import QuickSightCreateIngestionOperator
-from airflow.providers.amazon.aws.sensors.quicksight import QuickSightSensor
-
-DATA_SET_ID = os.getenv("DATA_SET_ID", "data-set-id")
-INGESTION_ID = os.getenv("INGESTION_ID", "ingestion-id")
-
-with DAG(
- dag_id="example_quicksight",
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=["example"],
- catchup=False,
-) as dag:
- # Create and Start the QuickSight SPICE data ingestion
- # and does not wait for its completion
- # [START howto_operator_quicksight_create_ingestion]
- quicksight_create_ingestion_no_waiting = QuickSightCreateIngestionOperator(
- task_id="quicksight_create_ingestion_no_waiting",
- data_set_id=DATA_SET_ID,
- ingestion_id=INGESTION_ID,
- wait_for_completion=False,
- )
- # [END howto_operator_quicksight_create_ingestion]
-
- # The following task checks the status of the QuickSight SPICE ingestion
- # job until it succeeds.
- # [START howto_sensor_quicksight]
- quicksight_job_status = QuickSightSensor(
- task_id="quicksight_job_status",
- data_set_id=DATA_SET_ID,
- ingestion_id=INGESTION_ID,
- )
- # [END howto_sensor_quicksight]
-
- chain(quicksight_create_ingestion_no_waiting, quicksight_job_status)
diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_event.py b/airflow/providers/amazon/aws/example_dags/example_rds_event.py
deleted file mode 100644
index 4ec8b6f5be3c4..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_rds_event.py
+++ /dev/null
@@ -1,58 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.rds import (
- RdsCreateEventSubscriptionOperator,
- RdsDeleteEventSubscriptionOperator,
-)
-
-SUBSCRIPTION_NAME = getenv("SUBSCRIPTION_NAME", "subscription-name")
-SNS_TOPIC_ARN = getenv("SNS_TOPIC_ARN", "arn:aws:sns:::MyTopic")
-RDS_DB_IDENTIFIER = getenv("RDS_DB_IDENTIFIER", "database-identifier")
-
-with DAG(
- dag_id='example_rds_event',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
- # [START howto_operator_rds_create_event_subscription]
- create_subscription = RdsCreateEventSubscriptionOperator(
- task_id='create_subscription',
- subscription_name=SUBSCRIPTION_NAME,
- sns_topic_arn=SNS_TOPIC_ARN,
- source_type='db-instance',
- source_ids=[RDS_DB_IDENTIFIER],
- event_categories=['availability'],
- )
- # [END howto_operator_rds_create_event_subscription]
-
- # [START howto_operator_rds_delete_event_subscription]
- delete_subscription = RdsDeleteEventSubscriptionOperator(
- task_id='delete_subscription',
- subscription_name=SUBSCRIPTION_NAME,
- )
- # [END howto_operator_rds_delete_event_subscription]
-
- chain(create_subscription, delete_subscription)
diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_export.py b/airflow/providers/amazon/aws/example_dags/example_rds_export.py
deleted file mode 100644
index 1dce5804910af..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_rds_export.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.rds import RdsCancelExportTaskOperator, RdsStartExportTaskOperator
-from airflow.providers.amazon.aws.sensors.rds import RdsExportTaskExistenceSensor
-
-RDS_EXPORT_TASK_IDENTIFIER = getenv("RDS_EXPORT_TASK_IDENTIFIER", "export-task-identifier")
-RDS_EXPORT_SOURCE_ARN = getenv(
- "RDS_EXPORT_SOURCE_ARN", "arn:aws:rds:::snapshot:snap-id"
-)
-BUCKET_NAME = getenv("BUCKET_NAME", "bucket-name")
-BUCKET_PREFIX = getenv("BUCKET_PREFIX", "bucket-prefix")
-ROLE_ARN = getenv("ROLE_ARN", "arn:aws:iam:::role/Role")
-KMS_KEY_ID = getenv("KMS_KEY_ID", "arn:aws:kms:::key/key-id")
-
-
-with DAG(
- dag_id='example_rds_export',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
- # [START howto_operator_rds_start_export_task]
- start_export = RdsStartExportTaskOperator(
- task_id='start_export',
- export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER,
- source_arn=RDS_EXPORT_SOURCE_ARN,
- s3_bucket_name=BUCKET_NAME,
- s3_prefix=BUCKET_PREFIX,
- iam_role_arn=ROLE_ARN,
- kms_key_id=KMS_KEY_ID,
- )
- # [END howto_operator_rds_start_export_task]
-
- # [START howto_operator_rds_cancel_export]
- cancel_export = RdsCancelExportTaskOperator(
- task_id='cancel_export',
- export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER,
- )
- # [END howto_operator_rds_cancel_export]
-
- # [START howto_sensor_rds_export_task_existence]
- export_sensor = RdsExportTaskExistenceSensor(
- task_id='export_sensor',
- export_task_identifier=RDS_EXPORT_TASK_IDENTIFIER,
- target_statuses=['canceled'],
- )
- # [END howto_sensor_rds_export_task_existence]
-
- chain(start_export, cancel_export, export_sensor)
diff --git a/airflow/providers/amazon/aws/example_dags/example_rds_snapshot.py b/airflow/providers/amazon/aws/example_dags/example_rds_snapshot.py
deleted file mode 100644
index f7e1d02e07d49..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_rds_snapshot.py
+++ /dev/null
@@ -1,76 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.rds import (
- RdsCopyDbSnapshotOperator,
- RdsCreateDbSnapshotOperator,
- RdsDeleteDbSnapshotOperator,
-)
-from airflow.providers.amazon.aws.sensors.rds import RdsSnapshotExistenceSensor
-
-RDS_DB_IDENTIFIER = getenv("RDS_DB_IDENTIFIER", "database-identifier")
-RDS_DB_SNAPSHOT_IDENTIFIER = getenv("RDS_DB_SNAPSHOT_IDENTIFIER", "database-1-snap")
-
-with DAG(
- dag_id='example_rds_snapshot',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
- # [START howto_operator_rds_create_db_snapshot]
- create_snapshot = RdsCreateDbSnapshotOperator(
- task_id='create_snapshot',
- db_type='instance',
- db_identifier=RDS_DB_IDENTIFIER,
- db_snapshot_identifier=RDS_DB_SNAPSHOT_IDENTIFIER,
- )
- # [END howto_operator_rds_create_db_snapshot]
-
- # [START howto_sensor_rds_snapshot_existence]
- snapshot_sensor = RdsSnapshotExistenceSensor(
- task_id='snapshot_sensor',
- db_type='instance',
- db_snapshot_identifier=RDS_DB_IDENTIFIER,
- target_statuses=['available'],
- )
- # [END howto_sensor_rds_snapshot_existence]
-
- # [START howto_operator_rds_copy_snapshot]
- copy_snapshot = RdsCopyDbSnapshotOperator(
- task_id='copy_snapshot',
- db_type='instance',
- source_db_snapshot_identifier=RDS_DB_IDENTIFIER,
- target_db_snapshot_identifier=f'{RDS_DB_IDENTIFIER}-copy',
- )
- # [END howto_operator_rds_copy_snapshot]
-
- # [START howto_operator_rds_delete_snapshot]
- delete_snapshot = RdsDeleteDbSnapshotOperator(
- task_id='delete_snapshot',
- db_type='instance',
- db_snapshot_identifier=RDS_DB_IDENTIFIER,
- )
- # [END howto_operator_rds_delete_snapshot]
-
- chain(create_snapshot, snapshot_sensor, copy_snapshot, delete_snapshot)
diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py b/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
deleted file mode 100644
index dc6deead3c618..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_cluster.py
+++ /dev/null
@@ -1,98 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.redshift_cluster import (
- RedshiftCreateClusterOperator,
- RedshiftDeleteClusterOperator,
- RedshiftPauseClusterOperator,
- RedshiftResumeClusterOperator,
-)
-from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-
-REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "redshift-cluster-1")
-
-with DAG(
- dag_id="example_redshift_cluster",
- start_date=datetime(2021, 1, 1),
- schedule_interval=None,
- catchup=False,
- tags=['example'],
-) as dag:
- # [START howto_operator_redshift_cluster]
- task_create_cluster = RedshiftCreateClusterOperator(
- task_id="redshift_create_cluster",
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- cluster_type="single-node",
- node_type="dc2.large",
- master_username="adminuser",
- master_user_password="dummypass",
- )
- # [END howto_operator_redshift_cluster]
-
- # [START howto_sensor_redshift_cluster]
- task_wait_cluster_available = RedshiftClusterSensor(
- task_id='sensor_redshift_cluster_available',
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- target_status='available',
- poke_interval=5,
- timeout=60 * 15,
- )
- # [END howto_sensor_redshift_cluster]
-
- # [START howto_operator_redshift_pause_cluster]
- task_pause_cluster = RedshiftPauseClusterOperator(
- task_id='redshift_pause_cluster',
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- )
- # [END howto_operator_redshift_pause_cluster]
-
- task_wait_cluster_paused = RedshiftClusterSensor(
- task_id='sensor_redshift_cluster_paused',
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- target_status='paused',
- poke_interval=5,
- timeout=60 * 15,
- )
-
- # [START howto_operator_redshift_resume_cluster]
- task_resume_cluster = RedshiftResumeClusterOperator(
- task_id='redshift_resume_cluster',
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- )
- # [END howto_operator_redshift_resume_cluster]
-
- # [START howto_operator_redshift_delete_cluster]
- task_delete_cluster = RedshiftDeleteClusterOperator(
- task_id="delete_cluster",
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- )
- # [END howto_operator_redshift_delete_cluster]
-
- chain(
- task_create_cluster,
- task_wait_cluster_available,
- task_pause_cluster,
- task_wait_cluster_paused,
- task_resume_cluster,
- task_delete_cluster,
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py b/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
deleted file mode 100644
index cfa3e4cefcb2d..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_data_execute_sql.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
-from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
-
-REDSHIFT_CLUSTER_IDENTIFIER = getenv("REDSHIFT_CLUSTER_IDENTIFIER", "redshift_cluster_identifier")
-REDSHIFT_DATABASE = getenv("REDSHIFT_DATABASE", "redshift_database")
-REDSHIFT_DATABASE_USER = getenv("REDSHIFT_DATABASE_USER", "awsuser")
-
-REDSHIFT_QUERY = """
-SELECT table_schema,
- table_name
-FROM information_schema.tables
-WHERE table_schema NOT IN ('information_schema', 'pg_catalog')
- AND table_type = 'BASE TABLE'
-ORDER BY table_schema,
- table_name;
- """
-POLL_INTERVAL = 10
-
-
-@task(task_id="output_results")
-def output_query_results(statement_id):
- hook = RedshiftDataHook()
- resp = hook.conn.get_statement_result(
- Id=statement_id,
- )
-
- print(resp)
- return resp
-
-
-with DAG(
- dag_id="example_redshift_data_execute_sql",
- start_date=datetime(2021, 1, 1),
- schedule_interval=None,
- catchup=False,
- tags=['example'],
-) as dag:
- # [START howto_operator_redshift_data]
- task_query = RedshiftDataOperator(
- task_id='redshift_query',
- cluster_identifier=REDSHIFT_CLUSTER_IDENTIFIER,
- database=REDSHIFT_DATABASE,
- db_user=REDSHIFT_DATABASE_USER,
- sql=REDSHIFT_QUERY,
- poll_interval=POLL_INTERVAL,
- await_result=True,
- )
- # [END howto_operator_redshift_data]
-
- task_output = output_query_results(task_query.output)
diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_sql.py b/airflow/providers/amazon/aws/example_dags/example_redshift_sql.py
deleted file mode 100644
index a71ef71934edc..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_sql.py
+++ /dev/null
@@ -1,79 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
-
-with DAG(
- dag_id="example_redshift_sql",
- start_date=datetime(2021, 1, 1),
- schedule_interval=None,
- catchup=False,
- tags=['example'],
-) as dag:
- setup__task_create_table = RedshiftSQLOperator(
- task_id='setup__create_table',
- sql="""
- CREATE TABLE IF NOT EXISTS fruit (
- fruit_id INTEGER,
- name VARCHAR NOT NULL,
- color VARCHAR NOT NULL
- );
- """,
- )
-
- setup__task_insert_data = RedshiftSQLOperator(
- task_id='setup__task_insert_data',
- sql=[
- "INSERT INTO fruit VALUES ( 1, 'Banana', 'Yellow');",
- "INSERT INTO fruit VALUES ( 2, 'Apple', 'Red');",
- "INSERT INTO fruit VALUES ( 3, 'Lemon', 'Yellow');",
- "INSERT INTO fruit VALUES ( 4, 'Grape', 'Purple');",
- "INSERT INTO fruit VALUES ( 5, 'Pear', 'Green');",
- "INSERT INTO fruit VALUES ( 6, 'Strawberry', 'Red');",
- ],
- )
-
- # [START howto_operator_redshift_sql]
- task_select_data = RedshiftSQLOperator(
- task_id='task_get_all_table_data', sql="""CREATE TABLE more_fruit AS SELECT * FROM fruit;"""
- )
- # [END howto_operator_redshift_sql]
-
- # [START howto_operator_redshift_sql_with_params]
- task_select_filtered_data = RedshiftSQLOperator(
- task_id='task_get_filtered_table_data',
- sql="""CREATE TABLE filtered_fruit AS SELECT * FROM fruit WHERE color = '{{ params.color }}';""",
- params={'color': 'Red'},
- )
- # [END howto_operator_redshift_sql_with_params]
-
- teardown__task_drop_table = RedshiftSQLOperator(
- task_id='teardown__drop_table',
- sql='DROP TABLE IF EXISTS fruit',
- )
-
- chain(
- setup__task_create_table,
- setup__task_insert_data,
- [task_select_data, task_select_filtered_data],
- teardown__task_drop_table,
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_redshift_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_redshift_to_s3.py
deleted file mode 100644
index 8116e02dc165c..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_redshift_to_s3.py
+++ /dev/null
@@ -1,43 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.providers.amazon.aws.transfers.redshift_to_s3 import RedshiftToS3Operator
-
-S3_BUCKET_NAME = getenv("S3_BUCKET_NAME", "s3_bucket_name")
-S3_KEY = getenv("S3_KEY", "s3_key")
-REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "redshift_table")
-
-with DAG(
- dag_id="example_redshift_to_s3",
- start_date=datetime(2021, 1, 1),
- schedule_interval=None,
- catchup=False,
- tags=['example'],
-) as dag:
- # [START howto_transfer_redshift_to_s3]
- task_transfer_redshift_to_s3 = RedshiftToS3Operator(
- task_id='transfer_redshift_to_s3',
- s3_bucket=S3_BUCKET_NAME,
- s3_key=S3_KEY,
- schema='PUBLIC',
- table=REDSHIFT_TABLE,
- )
- # [END howto_transfer_redshift_to_s3]
diff --git a/airflow/providers/amazon/aws/example_dags/example_s3.py b/airflow/providers/amazon/aws/example_dags/example_s3.py
deleted file mode 100644
index 7e06575d4a58b..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_s3.py
+++ /dev/null
@@ -1,224 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import os
-from datetime import datetime
-from typing import List
-
-from airflow.models.baseoperator import chain
-from airflow.models.dag import DAG
-from airflow.providers.amazon.aws.operators.s3 import (
- S3CopyObjectOperator,
- S3CreateBucketOperator,
- S3CreateObjectOperator,
- S3DeleteBucketOperator,
- S3DeleteBucketTaggingOperator,
- S3DeleteObjectsOperator,
- S3FileTransformOperator,
- S3GetBucketTaggingOperator,
- S3ListOperator,
- S3ListPrefixesOperator,
- S3PutBucketTaggingOperator,
-)
-from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeysUnchangedSensor
-
-BUCKET_NAME = os.environ.get('BUCKET_NAME', 'test-airflow-12345')
-BUCKET_NAME_2 = os.environ.get('BUCKET_NAME_2', 'test-airflow-123456')
-KEY = os.environ.get('KEY', 'key')
-KEY_2 = os.environ.get('KEY_2', 'key2')
-# Empty string prefix refers to the bucket root
-# See what prefix is here https://docs.aws.amazon.com/AmazonS3/latest/userguide/using-prefixes.html
-PREFIX = os.environ.get('PREFIX', '')
-DELIMITER = os.environ.get('DELIMITER', '/')
-TAG_KEY = os.environ.get('TAG_KEY', 'test-s3-bucket-tagging-key')
-TAG_VALUE = os.environ.get('TAG_VALUE', 'test-s3-bucket-tagging-value')
-DATA = os.environ.get(
- 'DATA',
- '''
-apple,0.5
-milk,2.5
-bread,4.0
-''',
-)
-
-with DAG(
- dag_id='example_s3',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=['example'],
-) as dag:
- # [START howto_sensor_s3_key_function_definition]
- def check_fn(files: List) -> bool:
- """
- Example of custom check: check if all files are bigger than 1kB
-
- :param files: List of S3 object attributes.
- Format: [{
- 'Size': int
- }]
- :return: true if the criteria is met
- :rtype: bool
- """
- return all(f.get('Size', 0) > 1024 for f in files)
-
- # [END howto_sensor_s3_key_function_definition]
-
- # [START howto_operator_s3_create_bucket]
- create_bucket = S3CreateBucketOperator(
- task_id='s3_create_bucket',
- bucket_name=BUCKET_NAME,
- )
- # [END howto_operator_s3_create_bucket]
-
- # [START howto_operator_s3_put_bucket_tagging]
- put_tagging = S3PutBucketTaggingOperator(
- task_id='s3_put_bucket_tagging',
- bucket_name=BUCKET_NAME,
- key=TAG_KEY,
- value=TAG_VALUE,
- )
- # [END howto_operator_s3_put_bucket_tagging]
-
- # [START howto_operator_s3_get_bucket_tagging]
- get_tagging = S3GetBucketTaggingOperator(
- task_id='s3_get_bucket_tagging',
- bucket_name=BUCKET_NAME,
- )
- # [END howto_operator_s3_get_bucket_tagging]
-
- # [START howto_operator_s3_delete_bucket_tagging]
- delete_tagging = S3DeleteBucketTaggingOperator(
- task_id='s3_delete_bucket_tagging',
- bucket_name=BUCKET_NAME,
- )
- # [END howto_operator_s3_delete_bucket_tagging]
-
- # [START howto_operator_s3_create_object]
- create_object = S3CreateObjectOperator(
- task_id="s3_create_object",
- s3_bucket=BUCKET_NAME,
- s3_key=KEY,
- data=DATA,
- replace=True,
- )
- # [END howto_operator_s3_create_object]
-
- # [START howto_operator_s3_list_prefixes]
- list_prefixes = S3ListPrefixesOperator(
- task_id="s3_list_prefix_operator",
- bucket=BUCKET_NAME,
- prefix=PREFIX,
- delimiter=DELIMITER,
- )
- # [END howto_operator_s3_list_prefixes]
-
- # [START howto_operator_s3_list]
- list_keys = S3ListOperator(
- task_id="s3_list_operator",
- bucket=BUCKET_NAME,
- prefix=PREFIX,
- )
- # [END howto_operator_s3_list]
-
- # [START howto_sensor_s3_key_single_key]
- # Check if a file exists
- sensor_one_key = S3KeySensor(
- task_id="s3_sensor_one_key",
- bucket_name=BUCKET_NAME,
- bucket_key=KEY,
- )
- # [END howto_sensor_s3_key_single_key]
-
- # [START howto_sensor_s3_key_multiple_keys]
- # Check if both files exist
- sensor_two_keys = S3KeySensor(
- task_id="s3_sensor_two_keys",
- bucket_name=BUCKET_NAME,
- bucket_key=[KEY, KEY_2],
- )
- # [END howto_sensor_s3_key_multiple_keys]
-
- # [START howto_sensor_s3_key_function]
- # Check if a file exists and match a certain pattern defined in check_fn
- sensor_key_with_function = S3KeySensor(
- task_id="s3_sensor_key_function",
- bucket_name=BUCKET_NAME,
- bucket_key=KEY,
- check_fn=check_fn,
- )
- # [END howto_sensor_s3_key_function]
-
- # [START howto_sensor_s3_keys_unchanged]
- sensor_keys_unchanged = S3KeysUnchangedSensor(
- task_id="s3_sensor_one_key_size",
- bucket_name=BUCKET_NAME_2,
- prefix=PREFIX,
- inactivity_period=10,
- )
- # [END howto_sensor_s3_keys_unchanged]
-
- # [START howto_operator_s3_copy_object]
- copy_object = S3CopyObjectOperator(
- task_id="s3_copy_object",
- source_bucket_name=BUCKET_NAME,
- dest_bucket_name=BUCKET_NAME_2,
- source_bucket_key=KEY,
- dest_bucket_key=KEY_2,
- )
- # [END howto_operator_s3_copy_object]
-
- # [START howto_operator_s3_file_transform]
- transforms_file = S3FileTransformOperator(
- task_id="s3_file_transform",
- source_s3_key=f's3://{BUCKET_NAME}/{KEY}',
- dest_s3_key=f's3://{BUCKET_NAME_2}/{KEY_2}',
- # Use `cp` command as transform script as an example
- transform_script='cp',
- replace=True,
- )
- # [END howto_operator_s3_file_transform]
-
- # [START howto_operator_s3_delete_objects]
- delete_objects = S3DeleteObjectsOperator(
- task_id="s3_delete_objects",
- bucket=BUCKET_NAME_2,
- keys=KEY_2,
- )
- # [END howto_operator_s3_delete_objects]
-
- # [START howto_operator_s3_delete_bucket]
- delete_bucket = S3DeleteBucketOperator(
- task_id='s3_delete_bucket', bucket_name=BUCKET_NAME, force_delete=True
- )
- # [END howto_operator_s3_delete_bucket]
-
- chain(
- create_bucket,
- put_tagging,
- get_tagging,
- delete_tagging,
- create_object,
- list_prefixes,
- list_keys,
- [sensor_one_key, sensor_two_keys, sensor_key_with_function],
- copy_object,
- transforms_file,
- sensor_keys_unchanged,
- delete_objects,
- delete_bucket,
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py
index 6ebe2501c613b..47ca88a0af72b 100644
--- a/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py
+++ b/airflow/providers/amazon/aws/example_dags/example_s3_to_ftp.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from __future__ import annotations
import os
from datetime import datetime
@@ -27,7 +27,6 @@
with models.DAG(
"example_s3_to_ftp",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag:
diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py
deleted file mode 100644
index 82ae0660053f1..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_s3_to_redshift.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
-from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
-
-S3_BUCKET_NAME = getenv("S3_BUCKET_NAME", "s3_bucket_name")
-S3_KEY = getenv("S3_KEY", "s3_filename")
-REDSHIFT_TABLE = getenv("REDSHIFT_TABLE", "redshift_table")
-
-
-@task(task_id='setup__add_sample_data_to_s3')
-def task_add_sample_data_to_s3():
- s3_hook = S3Hook()
- s3_hook.load_string("0,Airflow", f'{S3_KEY}/{REDSHIFT_TABLE}', S3_BUCKET_NAME, replace=True)
-
-
-@task(task_id='teardown__remove_sample_data_from_s3')
-def task_remove_sample_data_from_s3():
- s3_hook = S3Hook()
- if s3_hook.check_for_key(f'{S3_KEY}/{REDSHIFT_TABLE}', S3_BUCKET_NAME):
- s3_hook.delete_objects(S3_BUCKET_NAME, f'{S3_KEY}/{REDSHIFT_TABLE}')
-
-
-with DAG(
- dag_id="example_s3_to_redshift",
- start_date=datetime(2021, 1, 1),
- schedule_interval=None,
- catchup=False,
- tags=['example'],
-) as dag:
- add_sample_data_to_s3 = task_add_sample_data_to_s3()
-
- setup__task_create_table = RedshiftSQLOperator(
- sql=f'CREATE TABLE IF NOT EXISTS {REDSHIFT_TABLE}(Id int, Name varchar)',
- task_id='setup__create_table',
- )
- # [START howto_transfer_s3_to_redshift]
- task_transfer_s3_to_redshift = S3ToRedshiftOperator(
- s3_bucket=S3_BUCKET_NAME,
- s3_key=S3_KEY,
- schema='PUBLIC',
- table=REDSHIFT_TABLE,
- copy_options=['csv'],
- task_id='transfer_s3_to_redshift',
- )
- # [END howto_transfer_s3_to_redshift]
- teardown__task_drop_table = RedshiftSQLOperator(
- sql=f'DROP TABLE IF EXISTS {REDSHIFT_TABLE}',
- task_id='teardown__drop_table',
- )
-
- remove_sample_data_from_s3 = task_remove_sample_data_from_s3()
-
- chain(
- [add_sample_data_to_s3, setup__task_create_table],
- task_transfer_s3_to_redshift,
- [teardown__task_drop_table, remove_sample_data_from_s3],
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py b/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py
index d7983265e7697..1c625b84942e4 100644
--- a/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py
+++ b/airflow/providers/amazon/aws/example_dags/example_s3_to_sftp.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import os
from datetime import datetime
@@ -26,7 +27,6 @@
with models.DAG(
"example_s3_to_sftp",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag:
diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
deleted file mode 100644
index 41e7666eabb8a..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
+++ /dev/null
@@ -1,467 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import base64
-import os
-import subprocess
-from datetime import datetime
-from tempfile import NamedTemporaryFile
-
-import boto3
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.amazon.aws.operators.sagemaker import (
- SageMakerDeleteModelOperator,
- SageMakerModelOperator,
- SageMakerProcessingOperator,
- SageMakerTrainingOperator,
- SageMakerTransformOperator,
- SageMakerTuningOperator,
-)
-from airflow.providers.amazon.aws.sensors.sagemaker import (
- SageMakerTrainingSensor,
- SageMakerTransformSensor,
- SageMakerTuningSensor,
-)
-
-# Project name will be used in naming the S3 buckets and various tasks.
-# The dataset used in this example is identifying varieties of the Iris flower.
-PROJECT_NAME = 'iris'
-TIMESTAMP = '{{ ts_nodash }}'
-
-S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket')
-RAW_DATA_S3_KEY = f'{PROJECT_NAME}/preprocessing/input.csv'
-INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data'
-TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/results'
-PREDICTION_OUTPUT_S3_KEY = f'{PROJECT_NAME}/transform'
-
-PROCESSING_LOCAL_INPUT_PATH = '/opt/ml/processing/input'
-PROCESSING_LOCAL_OUTPUT_PATH = '/opt/ml/processing/output'
-
-MODEL_NAME = f'{PROJECT_NAME}-KNN-model'
-# Job names can't be reused, so appending a timestamp ensures it is unique.
-PROCESSING_JOB_NAME = f'{PROJECT_NAME}-processing-{TIMESTAMP}'
-TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}'
-TRANSFORM_JOB_NAME = f'{PROJECT_NAME}-transform-{TIMESTAMP}'
-TUNING_JOB_NAME = f'{PROJECT_NAME}-tune-{TIMESTAMP}'
-
-ROLE_ARN = os.getenv(
- 'SAGEMAKER_ROLE_ARN',
- 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole',
-)
-ECR_REPOSITORY = os.getenv('ECR_REPOSITORY', '1234567890.dkr.ecr.us-west-2.amazonaws.com/process_data')
-REGION = ECR_REPOSITORY.split('.')[3]
-
-# For this example we are using a subset of Fischer's Iris Data Set.
-# The full dataset can be found at UC Irvine's machine learning repository:
-# https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
-DATASET = """
- 5.1,3.5,1.4,0.2,Iris-setosa
- 4.9,3.0,1.4,0.2,Iris-setosa
- 7.0,3.2,4.7,1.4,Iris-versicolor
- 6.4,3.2,4.5,1.5,Iris-versicolor
- 4.9,2.5,4.5,1.7,Iris-virginica
- 7.3,2.9,6.3,1.8,Iris-virginica
- """
-SAMPLE_SIZE = DATASET.count('\n') - 1
-
-# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR
-# repo cited in public SageMaker documentation, so the account number does not need to be redacted.
-# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title
-KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn'
-
-TASK_TIMEOUT = {'MaxRuntimeInSeconds': 6 * 60}
-
-RESOURCE_CONFIG = {
- 'InstanceCount': 1,
- 'InstanceType': 'ml.m5.large',
- 'VolumeSizeInGB': 1,
-}
-
-TRAINING_DATA_SOURCE = {
- 'CompressionType': 'None',
- 'ContentType': 'text/csv',
- 'DataSource': { # type: ignore
- 'S3DataSource': {
- 'S3DataDistributionType': 'FullyReplicated',
- 'S3DataType': 'S3Prefix',
- 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv',
- }
- },
-}
-
-# Define configs for processing, training, model creation, and batch transform jobs
-SAGEMAKER_PROCESSING_JOB_CONFIG = {
- 'ProcessingJobName': PROCESSING_JOB_NAME,
- 'RoleArn': f'{ROLE_ARN}',
- 'ProcessingInputs': [
- {
- 'InputName': 'input',
- 'AppManaged': False,
- 'S3Input': {
- 'S3Uri': f's3://{S3_BUCKET}/{RAW_DATA_S3_KEY}',
- 'LocalPath': PROCESSING_LOCAL_INPUT_PATH,
- 'S3DataType': 'S3Prefix',
- 'S3InputMode': 'File',
- 'S3DataDistributionType': 'FullyReplicated',
- 'S3CompressionType': 'None',
- },
- },
- ],
- 'ProcessingOutputConfig': {
- 'Outputs': [
- {
- 'OutputName': 'output',
- 'S3Output': {
- 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}',
- 'LocalPath': PROCESSING_LOCAL_OUTPUT_PATH,
- 'S3UploadMode': 'EndOfJob',
- },
- 'AppManaged': False,
- }
- ]
- },
- 'ProcessingResources': {
- 'ClusterConfig': RESOURCE_CONFIG,
- },
- 'StoppingCondition': TASK_TIMEOUT,
- 'AppSpecification': {
- 'ImageUri': ECR_REPOSITORY,
- },
-}
-
-TRAINING_CONFIG = {
- 'TrainingJobName': TRAINING_JOB_NAME,
- 'RoleArn': ROLE_ARN,
- 'AlgorithmSpecification': {
- "TrainingImage": KNN_IMAGE_URI,
- "TrainingInputMode": "File",
- },
- 'HyperParameters': {
- 'predictor_type': 'classifier',
- 'feature_dim': '4',
- 'k': '3',
- 'sample_size': str(SAMPLE_SIZE),
- },
- 'InputDataConfig': [
- {
- 'ChannelName': 'train',
- **TRAINING_DATA_SOURCE, # type: ignore [arg-type]
- }
- ],
- 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'},
- 'ResourceConfig': RESOURCE_CONFIG,
- 'StoppingCondition': TASK_TIMEOUT,
-}
-
-MODEL_CONFIG = {
- 'ModelName': MODEL_NAME,
- 'ExecutionRoleArn': ROLE_ARN,
- 'PrimaryContainer': {
- 'Mode': 'SingleModel',
- 'Image': KNN_IMAGE_URI,
- 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz',
- },
-}
-
-TRANSFORM_CONFIG = {
- 'TransformJobName': TRANSFORM_JOB_NAME,
- 'ModelName': MODEL_NAME,
- 'TransformInput': {
- 'DataSource': {
- 'S3DataSource': {
- 'S3DataType': 'S3Prefix',
- 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/test.csv',
- }
- },
- 'SplitType': 'Line',
- 'ContentType': 'text/csv',
- },
- 'TransformOutput': {'S3OutputPath': f's3://{S3_BUCKET}/{PREDICTION_OUTPUT_S3_KEY}'},
- 'TransformResources': {
- 'InstanceCount': 1,
- 'InstanceType': 'ml.m5.large',
- },
-}
-
-TUNING_CONFIG = {
- 'HyperParameterTuningJobName': TUNING_JOB_NAME,
- 'HyperParameterTuningJobConfig': {
- 'Strategy': 'Bayesian',
- 'HyperParameterTuningJobObjective': {
- 'MetricName': 'test:accuracy',
- 'Type': 'Maximize',
- },
- 'ResourceLimits': {
- # You would bump these up in production as appropriate.
- 'MaxNumberOfTrainingJobs': 1,
- 'MaxParallelTrainingJobs': 1,
- },
- 'ParameterRanges': {
- 'CategoricalParameterRanges': [],
- 'IntegerParameterRanges': [
- # Set the min and max values of the hyperparameters you want to tune.
- {
- 'Name': 'k',
- 'MinValue': '1',
- "MaxValue": str(SAMPLE_SIZE),
- },
- {
- 'Name': 'sample_size',
- 'MinValue': '1',
- 'MaxValue': str(SAMPLE_SIZE),
- },
- ],
- },
- },
- 'TrainingJobDefinition': {
- 'StaticHyperParameters': {
- 'predictor_type': 'classifier',
- 'feature_dim': '4',
- },
- 'AlgorithmSpecification': {'TrainingImage': KNN_IMAGE_URI, 'TrainingInputMode': 'File'},
- 'InputDataConfig': [
- {
- 'ChannelName': 'train',
- **TRAINING_DATA_SOURCE, # type: ignore [arg-type]
- },
- {
- 'ChannelName': 'test',
- **TRAINING_DATA_SOURCE, # type: ignore [arg-type]
- },
- ],
- 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}'},
- 'ResourceConfig': RESOURCE_CONFIG,
- 'StoppingCondition': TASK_TIMEOUT,
- 'RoleArn': ROLE_ARN,
- },
-}
-
-
-# This script will be the entrypoint for the docker image which will handle preprocessing the raw data
-# NOTE: The following string must remain dedented as it is being written to a file.
-PREPROCESS_SCRIPT = (
- """
-import boto3
-import numpy as np
-import pandas as pd
-
-def main():
- # Load the Iris dataset from {input_path}/input.csv, split it into train/test
- # subsets, and write them to {output_path}/ for the Processing Operator.
-
- columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species']
- iris = pd.read_csv('{input_path}/input.csv', names=columns)
-
- # Process data
- iris['species'] = iris['species'].replace({{'Iris-virginica': 0, 'Iris-versicolor': 1, 'Iris-setosa': 2}})
- iris = iris[['species', 'sepal_length', 'sepal_width', 'petal_length', 'petal_width']]
-
- # Split into test and train data
- iris_train, iris_test = np.split(
- iris.sample(frac=1, random_state=np.random.RandomState()), [int(0.7 * len(iris))]
- )
-
- # Remove the "answers" from the test set
- iris_test.drop(['species'], axis=1, inplace=True)
-
- # Write the splits to disk
- iris_train.to_csv('{output_path}/train.csv', index=False, header=False)
- iris_test.to_csv('{output_path}/test.csv', index=False, header=False)
-
- print('Preprocessing Done.')
-
-if __name__ == "__main__":
- main()
-
- """
-).format(input_path=PROCESSING_LOCAL_INPUT_PATH, output_path=PROCESSING_LOCAL_OUTPUT_PATH)
-
-
-@task
-def upload_dataset_to_s3():
- """Uploads the provided dataset to a designated Amazon S3 bucket."""
- S3Hook().load_string(
- string_data=DATASET,
- bucket_name=S3_BUCKET,
- key=RAW_DATA_S3_KEY,
- replace=True,
- )
-
-
-@task
-def build_and_upload_docker_image():
- """
- We need a Docker image with the following requirements:
- - Has numpy, pandas, requests, and boto3 installed
- - Has our data preprocessing script mounted and set as the entry point
- """
-
- # Fetch and parse ECR Token to be used for the docker push
- ecr_client = boto3.client('ecr', region_name=REGION)
- token = ecr_client.get_authorization_token()
- credentials = (base64.b64decode(token['authorizationData'][0]['authorizationToken'])).decode('utf-8')
- username, password = credentials.split(':')
-
- with NamedTemporaryFile(mode='w+t') as preprocessing_script, NamedTemporaryFile(mode='w+t') as dockerfile:
-
- preprocessing_script.write(PREPROCESS_SCRIPT)
- preprocessing_script.flush()
-
- dockerfile.write(
- f"""
- FROM amazonlinux
- COPY {preprocessing_script.name.split('/')[2]} /preprocessing.py
- ADD credentials /credentials
- ENV AWS_SHARED_CREDENTIALS_FILE=/credentials
- RUN yum install python3 pip -y
- RUN pip3 install boto3 pandas requests
- CMD [ "python3", "/preprocessing.py"]
- """
- )
- dockerfile.flush()
-
- docker_build_and_push_commands = f"""
- cp /root/.aws/credentials /tmp/credentials &&
- docker build -f {dockerfile.name} -t {ECR_REPOSITORY} /tmp &&
- rm /tmp/credentials &&
- aws ecr get-login-password --region {REGION} |
- docker login --username {username} --password {password} {ECR_REPOSITORY} &&
- docker push {ECR_REPOSITORY}
- """
- docker_build = subprocess.Popen(
- docker_build_and_push_commands,
- shell=True,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- )
- _, err = docker_build.communicate()
-
- if docker_build.returncode != 0:
- raise RuntimeError(err)
-
-
-@task(trigger_rule='all_done')
-def cleanup():
- # Delete S3 Artifacts
- client = boto3.client('s3')
- object_keys = [
- key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents']
- ]
- for key in object_keys:
- client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]})
-
-
-with DAG(
- dag_id='example_sagemaker',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_sagemaker_processing]
- preprocess_raw_data = SageMakerProcessingOperator(
- task_id='preprocess_raw_data',
- config=SAGEMAKER_PROCESSING_JOB_CONFIG,
- do_xcom_push=False,
- )
- # [END howto_operator_sagemaker_processing]
-
- # [START howto_operator_sagemaker_training]
- train_model = SageMakerTrainingOperator(
- task_id='train_model',
- config=TRAINING_CONFIG,
- # Waits by default, setting as False to demonstrate the Sensor below.
- wait_for_completion=False,
- do_xcom_push=False,
- )
- # [END howto_operator_sagemaker_training]
-
- # [START howto_sensor_sagemaker_training]
- await_training = SageMakerTrainingSensor(
- task_id='await_training',
- job_name=TRAINING_JOB_NAME,
- )
- # [END howto_sensor_sagemaker_training]
-
- # [START howto_operator_sagemaker_model]
- create_model = SageMakerModelOperator(
- task_id='create_model',
- config=MODEL_CONFIG,
- do_xcom_push=False,
- )
- # [END howto_operator_sagemaker_model]
-
- # [START howto_operator_sagemaker_tuning]
- tune_model = SageMakerTuningOperator(
- task_id='tune_model',
- config=TUNING_CONFIG,
- # Waits by default, setting as False to demonstrate the Sensor below.
- wait_for_completion=False,
- do_xcom_push=False,
- )
- # [END howto_operator_sagemaker_tuning]
-
- # [START howto_sensor_sagemaker_tuning]
- await_tune = SageMakerTuningSensor(
- task_id='await_tuning',
- job_name=TUNING_JOB_NAME,
- )
- # [END howto_sensor_sagemaker_tuning]
-
- # [START howto_operator_sagemaker_transform]
- test_model = SageMakerTransformOperator(
- task_id='test_model',
- config=TRANSFORM_CONFIG,
- # Waits by default, setting as False to demonstrate the Sensor below.
- wait_for_completion=False,
- do_xcom_push=False,
- )
- # [END howto_operator_sagemaker_transform]
-
- # [START howto_sensor_sagemaker_transform]
- await_transform = SageMakerTransformSensor(
- task_id='await_transform',
- job_name=TRANSFORM_JOB_NAME,
- )
- # [END howto_sensor_sagemaker_transform]
-
- # Trigger rule set to "all_done" so clean up will run regardless of success on other tasks.
- # [START howto_operator_sagemaker_delete_model]
- delete_model = SageMakerDeleteModelOperator(
- task_id='delete_model',
- config={'ModelName': MODEL_NAME},
- trigger_rule='all_done',
- )
- # [END howto_operator_sagemaker_delete_model]
-
- (
- upload_dataset_to_s3()
- >> build_and_upload_docker_image()
- >> preprocess_raw_data
- >> train_model
- >> await_training
- >> create_model
- >> tune_model
- >> await_tune
- >> test_model
- >> await_transform
- >> cleanup()
- >> delete_model
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py b/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
deleted file mode 100644
index 52ee5303f2553..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_sagemaker_endpoint.py
+++ /dev/null
@@ -1,230 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import json
-import os
-from datetime import datetime
-
-import boto3
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.providers.amazon.aws.operators.s3 import S3CreateObjectOperator
-from airflow.providers.amazon.aws.operators.sagemaker import (
- SageMakerDeleteModelOperator,
- SageMakerEndpointConfigOperator,
- SageMakerEndpointOperator,
- SageMakerModelOperator,
- SageMakerTrainingOperator,
-)
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor
-
-# Project name will be used in naming the S3 buckets and various tasks.
-# The dataset used in this example is identifying varieties of the Iris flower.
-PROJECT_NAME = 'iris'
-TIMESTAMP = '{{ ts_nodash }}'
-
-S3_BUCKET = os.getenv('S3_BUCKET', 'S3_bucket')
-ROLE_ARN = os.getenv(
- 'SAGEMAKER_ROLE_ARN',
- 'arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole',
-)
-INPUT_DATA_S3_KEY = f'{PROJECT_NAME}/processed-input-data'
-TRAINING_OUTPUT_S3_KEY = f'{PROJECT_NAME}/training-results'
-
-MODEL_NAME = f'{PROJECT_NAME}-KNN-model'
-ENDPOINT_NAME = f'{PROJECT_NAME}-endpoint'
-# Job names can't be reused, so appending a timestamp ensures it is unique.
-ENDPOINT_CONFIG_JOB_NAME = f'{PROJECT_NAME}-endpoint-config-{TIMESTAMP}'
-TRAINING_JOB_NAME = f'{PROJECT_NAME}-train-{TIMESTAMP}'
-
-# For an example of how to obtain the following train and test data, please see
-# https://github.com/apache/airflow/blob/main/airflow/providers/amazon/aws/example_dags/example_sagemaker.py
-TRAIN_DATA = '0,4.9,2.5,4.5,1.7\n1,7.0,3.2,4.7,1.4\n0,7.3,2.9,6.3,1.8\n2,5.1,3.5,1.4,0.2\n'
-SAMPLE_TEST_DATA = '6.4,3.2,4.5,1.5'
-
-# The URI of an Amazon-provided docker image for handling KNN model training. This is a public ECR
-# repo cited in public SageMaker documentation, so the account number does not need to be redacted.
-# For more info see: https://docs.aws.amazon.com/sagemaker/latest/dg/ecr-us-west-2.html#knn-us-west-2.title
-KNN_IMAGE_URI = '174872318107.dkr.ecr.us-west-2.amazonaws.com/knn'
-
-# Define configs for processing, training, model creation, and batch transform jobs
-TRAINING_CONFIG = {
- 'TrainingJobName': TRAINING_JOB_NAME,
- 'RoleArn': ROLE_ARN,
- 'AlgorithmSpecification': {
- "TrainingImage": KNN_IMAGE_URI,
- "TrainingInputMode": "File",
- },
- 'HyperParameters': {
- 'predictor_type': 'classifier',
- 'feature_dim': '4',
- 'k': '3',
- 'sample_size': '6',
- },
- 'InputDataConfig': [
- {
- 'ChannelName': 'train',
- 'CompressionType': 'None',
- 'ContentType': 'text/csv',
- 'DataSource': {
- 'S3DataSource': {
- 'S3DataDistributionType': 'FullyReplicated',
- 'S3DataType': 'S3Prefix',
- 'S3Uri': f's3://{S3_BUCKET}/{INPUT_DATA_S3_KEY}/train.csv',
- }
- },
- }
- ],
- 'OutputDataConfig': {'S3OutputPath': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/'},
- 'ResourceConfig': {
- 'InstanceCount': 1,
- 'InstanceType': 'ml.m5.large',
- 'VolumeSizeInGB': 1,
- },
- 'StoppingCondition': {'MaxRuntimeInSeconds': 6 * 60},
-}
-
-MODEL_CONFIG = {
- 'ModelName': MODEL_NAME,
- 'ExecutionRoleArn': ROLE_ARN,
- 'PrimaryContainer': {
- 'Mode': 'SingleModel',
- 'Image': KNN_IMAGE_URI,
- 'ModelDataUrl': f's3://{S3_BUCKET}/{TRAINING_OUTPUT_S3_KEY}/{TRAINING_JOB_NAME}/output/model.tar.gz',
- },
-}
-
-ENDPOINT_CONFIG_CONFIG = {
- 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME,
- 'ProductionVariants': [
- {
- 'VariantName': f'{PROJECT_NAME}-demo',
- 'ModelName': MODEL_NAME,
- 'InstanceType': 'ml.t2.medium',
- 'InitialInstanceCount': 1,
- },
- ],
-}
-
-DEPLOY_ENDPOINT_CONFIG = {
- 'EndpointName': ENDPOINT_NAME,
- 'EndpointConfigName': ENDPOINT_CONFIG_JOB_NAME,
-}
-
-
-@task
-def call_endpoint():
- runtime = boto3.Session().client('sagemaker-runtime')
-
- response = runtime.invoke_endpoint(
- EndpointName=ENDPOINT_NAME,
- ContentType='text/csv',
- Body=SAMPLE_TEST_DATA,
- )
-
- return json.loads(response["Body"].read().decode())['predictions']
-
-
-@task(trigger_rule='all_done')
-def cleanup():
- # Delete Endpoint and Endpoint Config
- client = boto3.client('sagemaker')
- endpoint_config_name = client.list_endpoint_configs()['EndpointConfigs'][0]['EndpointConfigName']
- client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
- client.delete_endpoint(EndpointName=ENDPOINT_NAME)
-
- # Delete S3 Artifacts
- client = boto3.client('s3')
- object_keys = [
- key['Key'] for key in client.list_objects_v2(Bucket=S3_BUCKET, Prefix=PROJECT_NAME)['Contents']
- ]
- for key in object_keys:
- client.delete_objects(Bucket=S3_BUCKET, Delete={'Objects': [{'Key': key}]})
-
-
-with DAG(
- dag_id='example_sagemaker_endpoint',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- upload_data = S3CreateObjectOperator(
- task_id='upload_data',
- s3_bucket=S3_BUCKET,
- s3_key=f'{INPUT_DATA_S3_KEY}/train.csv',
- data=TRAIN_DATA,
- replace=True,
- )
-
- train_model = SageMakerTrainingOperator(
- task_id='train_model',
- config=TRAINING_CONFIG,
- do_xcom_push=False,
- )
-
- create_model = SageMakerModelOperator(
- task_id='create_model',
- config=MODEL_CONFIG,
- do_xcom_push=False,
- )
-
- # [START howto_operator_sagemaker_endpoint_config]
- configure_endpoint = SageMakerEndpointConfigOperator(
- task_id='configure_endpoint',
- config=ENDPOINT_CONFIG_CONFIG,
- do_xcom_push=False,
- )
- # [END howto_operator_sagemaker_endpoint_config]
-
- # [START howto_operator_sagemaker_endpoint]
- deploy_endpoint = SageMakerEndpointOperator(
- task_id='deploy_endpoint',
- config=DEPLOY_ENDPOINT_CONFIG,
- # Waits by default, > train_model
- >> create_model
- >> configure_endpoint
- >> deploy_endpoint
- >> await_endpoint
- >> call_endpoint()
- >> cleanup()
- >> delete_model
- )
diff --git a/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py
index 735cda4d3af14..067b5fa86ef5f 100644
--- a/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_salesforce_to_s3.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This is a basic example DAG for using `SalesforceToS3Operator` to retrieve Salesforce account
data and upload it to an Amazon S3 bucket.
"""
+from __future__ import annotations
from datetime import datetime
from os import getenv
@@ -32,7 +32,6 @@
with DAG(
dag_id="example_salesforce_to_s3",
- schedule_interval=None,
start_date=datetime(2021, 7, 8),
catchup=False,
tags=["example"],
diff --git a/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py
index 0e2407a7d3546..24f480fb2fec3 100644
--- a/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py
+++ b/airflow/providers/amazon/aws/example_dags/example_sftp_to_s3.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from __future__ import annotations
import os
from datetime import datetime
@@ -27,7 +27,6 @@
with models.DAG(
"example_sftp_to_s3",
- schedule_interval=None,
start_date=datetime(2021, 1, 1),
catchup=False,
) as dag:
diff --git a/airflow/providers/amazon/aws/example_dags/example_sns.py b/airflow/providers/amazon/aws/example_dags/example_sns.py
deleted file mode 100644
index 782156b14c3d3..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_sns.py
+++ /dev/null
@@ -1,39 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import environ
-
-from airflow import DAG
-from airflow.providers.amazon.aws.operators.sns import SnsPublishOperator
-
-SNS_TOPIC_ARN = environ.get('SNS_TOPIC_ARN', 'arn:aws:sns:us-west-2:123456789012:dummy-topic-name')
-
-with DAG(
- dag_id='example_sns',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_sns_publish_operator]
- publish = SnsPublishOperator(
- task_id='publish_message',
- target_arn=SNS_TOPIC_ARN,
- message='This is a sample message sent to SNS via an Apache Airflow DAG task.',
- )
- # [END howto_operator_sns_publish_operator]
diff --git a/airflow/providers/amazon/aws/example_dags/example_sql_to_s3.py b/airflow/providers/amazon/aws/example_dags/example_sql_to_s3.py
deleted file mode 100644
index df2abee0f3052..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_sql_to_s3.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import os
-from datetime import datetime
-
-from airflow import models
-from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator
-
-S3_BUCKET = os.environ.get("S3_BUCKET", "test-bucket")
-S3_KEY = os.environ.get("S3_KEY", "key")
-SQL_QUERY = os.environ.get("SQL_QUERY", "SHOW tables")
-
-with models.DAG(
- "example_sql_to_s3",
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- catchup=False,
-) as dag:
- # [START howto_transfer_sql_to_s3]
- sql_to_s3_task = SqlToS3Operator(
- task_id="sql_to_s3_task",
- sql_conn_id="mysql_default",
- query=SQL_QUERY,
- s3_bucket=S3_BUCKET,
- s3_key=S3_KEY,
- replace=True,
- )
- # [END howto_transfer_sql_to_s3]
diff --git a/airflow/providers/amazon/aws/example_dags/example_sqs.py b/airflow/providers/amazon/aws/example_dags/example_sqs.py
deleted file mode 100644
index 69ff1e5b75831..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_sqs.py
+++ /dev/null
@@ -1,72 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import getenv
-
-from airflow import DAG
-from airflow.decorators import task
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.hooks.sqs import SqsHook
-from airflow.providers.amazon.aws.operators.sqs import SqsPublishOperator
-from airflow.providers.amazon.aws.sensors.sqs import SqsSensor
-
-QUEUE_NAME = getenv('QUEUE_NAME', 'Airflow-Example-Queue')
-
-
-@task(task_id="create_queue")
-def create_queue_fn():
- """Create the example queue"""
- hook = SqsHook()
- result = hook.create_queue(queue_name=QUEUE_NAME)
- return result['QueueUrl']
-
-
-@task(task_id="delete_queue")
-def delete_queue_fn(queue_url):
- """Delete the example queue"""
- hook = SqsHook()
- hook.get_conn().delete_queue(QueueUrl=queue_url)
-
-
-with DAG(
- dag_id='example_sqs',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- create_queue = create_queue_fn()
-
- # [START howto_operator_sqs]
- publish_to_queue = SqsPublishOperator(
- task_id='publish_to_queue',
- sqs_queue=create_queue,
- message_content="{{ task_instance }}-{{ execution_date }}",
- )
- # [END howto_operator_sqs]
-
- # [START howto_sensor_sqs]
- read_from_queue = SqsSensor(
- task_id='read_from_queue',
- sqs_queue=create_queue,
- )
- # [END howto_sensor_sqs]
-
- delete_queue = delete_queue_fn(create_queue)
-
- chain(create_queue, publish_to_queue, read_from_queue, delete_queue)
diff --git a/airflow/providers/amazon/aws/example_dags/example_step_functions.py b/airflow/providers/amazon/aws/example_dags/example_step_functions.py
deleted file mode 100644
index 02763e3ea13f1..0000000000000
--- a/airflow/providers/amazon/aws/example_dags/example_step_functions.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-from datetime import datetime
-from os import environ
-
-from airflow import DAG
-from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.operators.step_function import (
- StepFunctionGetExecutionOutputOperator,
- StepFunctionStartExecutionOperator,
-)
-from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor
-
-STEP_FUNCTIONS_STATE_MACHINE_ARN = environ.get('STEP_FUNCTIONS_STATE_MACHINE_ARN', 'state_machine_arn')
-
-with DAG(
- dag_id='example_step_functions',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
-
- # [START howto_operator_step_function_start_execution]
- start_execution = StepFunctionStartExecutionOperator(
- task_id='start_execution', state_machine_arn=STEP_FUNCTIONS_STATE_MACHINE_ARN
- )
- # [END howto_operator_step_function_start_execution]
-
- # [START howto_sensor_step_function_execution]
- wait_for_execution = StepFunctionExecutionSensor(
- task_id='wait_for_execution', execution_arn=start_execution.output
- )
- # [END howto_sensor_step_function_execution]
-
- # [START howto_operator_step_function_get_execution_output]
- get_execution_output = StepFunctionGetExecutionOutputOperator(
- task_id='get_execution_output', execution_arn=start_execution.output
- )
- # [END howto_operator_step_function_get_execution_output]
-
- chain(start_execution, wait_for_execution, get_execution_output)
diff --git a/airflow/providers/amazon/aws/exceptions.py b/airflow/providers/amazon/aws/exceptions.py
index 27f46dfae7563..b606dc504f2dc 100644
--- a/airflow/providers/amazon/aws/exceptions.py
+++ b/airflow/providers/amazon/aws/exceptions.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state
-import warnings
class EcsTaskFailToStart(Exception):
@@ -42,19 +42,3 @@ def __init__(self, failures: list, message: str):
def __reduce__(self):
return EcsOperatorError, (self.failures, self.message)
-
-
-class ECSOperatorError(EcsOperatorError):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.exceptions.EcsOperatorError`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.amazon.aws.exceptions.EcsOperatorError`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/appflow.py b/airflow/providers/amazon/aws/hooks/appflow.py
new file mode 100644
index 0000000000000..3bf57e50e0001
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/appflow.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from time import sleep
+from typing import TYPE_CHECKING
+
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+if TYPE_CHECKING:
+ from mypy_boto3_appflow.client import AppflowClient
+ from mypy_boto3_appflow.type_defs import TaskTypeDef
+
+
+class AppflowHook(AwsBaseHook):
+ """
+ Interact with Amazon Appflow, using the boto3 library.
+ Hook attribute ``conn`` exposes all methods that are listed in the documentation.
+
+ .. seealso::
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/appflow.html
+ - https://docs.aws.amazon.com/appflow/1.0/APIReference/Welcome.html
+
+ Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and
+ are passed down to the underlying AwsBaseHook.
+
+ .. seealso::
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+
+ """
+
+ EVENTUAL_CONSISTENCY_OFFSET: int = 15 # seconds
+ EVENTUAL_CONSISTENCY_POLLING: int = 10 # seconds
+
+ def __init__(self, *args, **kwargs) -> None:
+ kwargs["client_type"] = "appflow"
+ super().__init__(*args, **kwargs)
+
+ @cached_property
+ def conn(self) -> AppflowClient:
+ """Get the underlying boto3 Appflow client (cached)"""
+ return super().conn
+
+ def run_flow(self, flow_name: str, poll_interval: int = 20) -> str:
+ """
+ Execute an AppFlow run.
+
+ :param flow_name: The flow name
+ :param poll_interval: Time (seconds) to wait between two consecutive calls to check the run status
+ :return: The run execution ID
+ """
+ ts_before: datetime = datetime.now(timezone.utc)
+ sleep(self.EVENTUAL_CONSISTENCY_OFFSET)
+ response_start = self.conn.start_flow(flowName=flow_name)
+ execution_id = response_start["executionId"]
+ self.log.info("executionId: %s", execution_id)
+
+ response_desc = self.conn.describe_flow(flowName=flow_name)
+ last_exec_details = response_desc["lastRunExecutionDetails"]
+
+ # Wait for Appflow eventual consistency
+ self.log.info("Waiting for Appflow eventual consistency...")
+ while (
+ response_desc.get("lastRunExecutionDetails", {}).get(
+ "mostRecentExecutionTime", datetime(1970, 1, 1, tzinfo=timezone.utc)
+ )
+ < ts_before
+ ):
+ sleep(self.EVENTUAL_CONSISTENCY_POLLING)
+ response_desc = self.conn.describe_flow(flowName=flow_name)
+ last_exec_details = response_desc["lastRunExecutionDetails"]
+
+ # Wait until the flow run stops
+ self.log.info("Waiting for flow run...")
+ while (
+ "mostRecentExecutionStatus" not in last_exec_details
+ or last_exec_details["mostRecentExecutionStatus"] == "InProgress"
+ ):
+ sleep(poll_interval)
+ response_desc = self.conn.describe_flow(flowName=flow_name)
+ last_exec_details = response_desc["lastRunExecutionDetails"]
+
+ self.log.info("lastRunExecutionDetails: %s", last_exec_details)
+
+ if last_exec_details["mostRecentExecutionStatus"] == "Error":
+ raise Exception(f"Flow error:\n{json.dumps(response_desc, default=str)}")
+
+ return execution_id
+
+ def update_flow_filter(
+ self, flow_name: str, filter_tasks: list[TaskTypeDef], set_trigger_ondemand: bool = False
+ ) -> None:
+ """
+ Update the flow task filter.
+ All filters will be removed if an empty array is passed to filter_tasks.
+
+ :param flow_name: The flow name
+ :param filter_tasks: List of flow tasks to be added
+ :param set_trigger_ondemand: If True, set the trigger to on-demand; otherwise, keep the trigger as is
+ :return: None
+ """
+ response = self.conn.describe_flow(flowName=flow_name)
+ connector_type = response["sourceFlowConfig"]["connectorType"]
+ tasks: list[TaskTypeDef] = []
+
+ # cleanup old filter tasks
+ for task in response["tasks"]:
+ if (
+ task["taskType"] == "Filter"
+ and task.get("connectorOperator", {}).get(connector_type) != "PROJECTION"
+ ):
+ self.log.info("Removing task: %s", task)
+ else:
+ tasks.append(task) # List of non-filter tasks
+
+ tasks += filter_tasks # Add the new filter tasks
+
+ if set_trigger_ondemand:
+ # Clean up attribute to force on-demand trigger
+ del response["triggerConfig"]["triggerProperties"]
+
+ self.conn.update_flow(
+ flowName=response["flowName"],
+ destinationFlowConfigList=response["destinationFlowConfigList"],
+ sourceFlowConfig=response["sourceFlowConfig"],
+ triggerConfig=response["triggerConfig"],
+ description=response.get("description", "Flow description."),
+ tasks=tasks,
+ )
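Reviewer note: the task-filtering step in ``update_flow_filter`` above can be read in isolation as a pure function. The following is a minimal sketch under stated assumptions, not the hook's actual API: the helper name ``replace_filter_tasks`` and the plain-``dict`` task shape are hypothetical and stand in for ``TaskTypeDef``. It keeps every non-filter task, keeps filter tasks whose operator for the source connector is ``PROJECTION``, and appends the new filter tasks.

```python
from typing import Any


def replace_filter_tasks(
    existing_tasks: list[dict[str, Any]],
    new_filter_tasks: list[dict[str, Any]],
    connector_type: str,
) -> list[dict[str, Any]]:
    """Hypothetical helper mirroring the cleanup loop in update_flow_filter."""
    kept: list[dict[str, Any]] = []
    for task in existing_tasks:
        is_filter = task["taskType"] == "Filter"
        # PROJECTION filters select the fields to transfer, so they are preserved.
        is_projection = task.get("connectorOperator", {}).get(connector_type) == "PROJECTION"
        if is_filter and not is_projection:
            continue  # drop the old filter task
        kept.append(task)
    # New filter tasks go after the preserved tasks, as in the hook.
    return kept + new_filter_tasks
```

Passing an empty ``new_filter_tasks`` list removes all non-projection filters, which matches the docstring's statement that all filters are removed when an empty array is passed.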
diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index 82e69ccf6f001..3c70898a34298 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -15,7 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This module contains AWS Athena hook.
@@ -23,9 +22,11 @@
PageIterator
"""
+from __future__ import annotations
+
import warnings
from time import sleep
-from typing import Any, Dict, Optional
+from typing import Any
from botocore.paginate import PageIterator
@@ -46,14 +47,14 @@ class AthenaHook(AwsBaseHook):
"""
INTERMEDIATE_STATES = (
- 'QUEUED',
- 'RUNNING',
+ "QUEUED",
+ "RUNNING",
)
FAILURE_STATES = (
- 'FAILED',
- 'CANCELLED',
+ "FAILED",
+ "CANCELLED",
)
- SUCCESS_STATES = ('SUCCEEDED',)
+ SUCCESS_STATES = ("SUCCEEDED",)
TERMINAL_STATES = (
"SUCCEEDED",
"FAILED",
@@ -61,16 +62,16 @@ class AthenaHook(AwsBaseHook):
)
def __init__(self, *args: Any, sleep_time: int = 30, **kwargs: Any) -> None:
- super().__init__(client_type='athena', *args, **kwargs) # type: ignore
+ super().__init__(client_type="athena", *args, **kwargs) # type: ignore
self.sleep_time = sleep_time
def run_query(
self,
query: str,
- query_context: Dict[str, str],
- result_configuration: Dict[str, Any],
- client_request_token: Optional[str] = None,
- workgroup: str = 'primary',
+ query_context: dict[str, str],
+ result_configuration: dict[str, Any],
+ client_request_token: str | None = None,
+ workgroup: str = "primary",
) -> str:
"""
Run Presto query on athena with provided config and return submitted query_execution_id
@@ -83,17 +84,17 @@ def run_query(
:return: str
"""
params = {
- 'QueryString': query,
- 'QueryExecutionContext': query_context,
- 'ResultConfiguration': result_configuration,
- 'WorkGroup': workgroup,
+ "QueryString": query,
+ "QueryExecutionContext": query_context,
+ "ResultConfiguration": result_configuration,
+ "WorkGroup": workgroup,
}
if client_request_token:
- params['ClientRequestToken'] = client_request_token
+ params["ClientRequestToken"] = client_request_token
response = self.get_conn().start_query_execution(**params)
- return response['QueryExecutionId']
+ return response["QueryExecutionId"]
- def check_query_status(self, query_execution_id: str) -> Optional[str]:
+ def check_query_status(self, query_execution_id: str) -> str | None:
"""
Fetch the status of submitted athena query. Returns None or one of valid query states.
@@ -103,15 +104,15 @@ def check_query_status(self, query_execution_id: str) -> Optional[str]:
response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
state = None
try:
- state = response['QueryExecution']['Status']['State']
+ state = response["QueryExecution"]["Status"]["State"]
except Exception as ex:
- self.log.error('Exception while getting query state %s', ex)
+ self.log.error("Exception while getting query state %s", ex)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
return state
- def get_state_change_reason(self, query_execution_id: str) -> Optional[str]:
+ def get_state_change_reason(self, query_execution_id: str) -> str | None:
"""
Fetch the reason for a state change (e.g. error message). Returns None or reason string.
@@ -121,17 +122,17 @@ def get_state_change_reason(self, query_execution_id: str) -> Optional[str]:
response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
reason = None
try:
- reason = response['QueryExecution']['Status']['StateChangeReason']
+ reason = response["QueryExecution"]["Status"]["StateChangeReason"]
except Exception as ex:
- self.log.error('Exception while getting query state change reason: %s', ex)
+ self.log.error("Exception while getting query state change reason: %s", ex)
finally:
# The error is being absorbed here and is being handled by the caller.
# The error is being absorbed to implement retries.
return reason
def get_query_results(
- self, query_execution_id: str, next_token_id: Optional[str] = None, max_results: int = 1000
- ) -> Optional[dict]:
+ self, query_execution_id: str, next_token_id: str | None = None, max_results: int = 1000
+ ) -> dict | None:
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else dict of query output
@@ -143,23 +144,23 @@ def get_query_results(
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
- self.log.error('Invalid Query state')
+ self.log.error("Invalid Query state")
return None
elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
- result_params = {'QueryExecutionId': query_execution_id, 'MaxResults': max_results}
+ result_params = {"QueryExecutionId": query_execution_id, "MaxResults": max_results}
if next_token_id:
- result_params['NextToken'] = next_token_id
+ result_params["NextToken"] = next_token_id
return self.get_conn().get_query_results(**result_params)
def get_query_results_paginator(
self,
query_execution_id: str,
- max_items: Optional[int] = None,
- page_size: Optional[int] = None,
- starting_token: Optional[str] = None,
- ) -> Optional[PageIterator]:
+ max_items: int | None = None,
+ page_size: int | None = None,
+ starting_token: str | None = None,
+ ) -> PageIterator | None:
"""
Fetch submitted athena query results. returns none if query is in intermediate state or
failed/cancelled state else a paginator to iterate through pages of results. If you
@@ -173,46 +174,66 @@ def get_query_results_paginator(
"""
query_state = self.check_query_status(query_execution_id)
if query_state is None:
- self.log.error('Invalid Query state (null)')
+ self.log.error("Invalid Query state (null)")
return None
if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
self.log.error('Query is in "%s" state. Cannot fetch results', query_state)
return None
result_params = {
- 'QueryExecutionId': query_execution_id,
- 'PaginationConfig': {
- 'MaxItems': max_items,
- 'PageSize': page_size,
- 'StartingToken': starting_token,
+ "QueryExecutionId": query_execution_id,
+ "PaginationConfig": {
+ "MaxItems": max_items,
+ "PageSize": page_size,
+ "StartingToken": starting_token,
},
}
- paginator = self.get_conn().get_paginator('get_query_results')
+ paginator = self.get_conn().get_paginator("get_query_results")
return paginator.paginate(**result_params)
- def poll_query_status(self, query_execution_id: str, max_tries: Optional[int] = None) -> Optional[str]:
+ def poll_query_status(
+ self,
+ query_execution_id: str,
+ max_tries: int | None = None,
+ max_polling_attempts: int | None = None,
+ ) -> str | None:
"""
Poll the status of submitted athena query until query state reaches final state.
Returns one of the final states
:param query_execution_id: Id of submitted athena query
- :param max_tries: Number of times to poll for query state before function exits
+ :param max_tries: Deprecated - Use max_polling_attempts instead
+ :param max_polling_attempts: Number of times to poll for query state before function exits
:return: str
"""
+ if max_tries:
+ warnings.warn(
+ f"Passing 'max_tries' to {self.__class__.__name__}.poll_query_status is deprecated "
+ f"and will be removed in a future release. Please use 'max_polling_attempts' instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ max_polling_attempts = max_tries
+
try_number = 1
- final_query_state = None # Query state when query reaches final state or max_tries reached
+ final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
while True:
query_state = self.check_query_status(query_execution_id)
if query_state is None:
- self.log.info('Trial %s: Invalid query state. Retrying again', try_number)
+ self.log.info("Trial %s: Invalid query state. Retrying again", try_number)
elif query_state in self.TERMINAL_STATES:
self.log.info(
- 'Trial %s: Query execution completed. Final state is %s}', try_number, query_state
+ "Trial %s: Query execution completed. Final state is %s", try_number, query_state
)
final_query_state = query_state
break
else:
- self.log.info('Trial %s: Query is still in non-terminal state - %s', try_number, query_state)
- if max_tries and try_number >= max_tries: # Break loop if max_tries reached
+ self.log.info("Trial %s: Query is still in non-terminal state - %s", try_number, query_state)
+ if (
+ max_polling_attempts and try_number >= max_polling_attempts
+ ): # Break loop if max_polling_attempts reached
final_query_state = query_state
break
try_number += 1
@@ -233,7 +254,7 @@ def get_output_location(self, query_execution_id: str) -> str:
if response:
try:
- output_location = response['QueryExecution']['ResultConfiguration']['OutputLocation']
+ output_location = response["QueryExecution"]["ResultConfiguration"]["OutputLocation"]
except KeyError:
self.log.error("Error retrieving OutputLocation")
raise
@@ -244,7 +265,7 @@ def get_output_location(self, query_execution_id: str) -> str:
return output_location
- def stop_query(self, query_execution_id: str) -> Dict:
+ def stop_query(self, query_execution_id: str) -> dict:
"""
Cancel the submitted athena query
@@ -252,18 +273,3 @@ def stop_query(self, query_execution_id: str) -> Dict:
:return: dict
"""
return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)
-
-
-class AWSAthenaHook(AthenaHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.athena.AthenaHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.athena.AthenaHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
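Reviewer note: the ``max_tries`` to ``max_polling_attempts`` migration in ``poll_query_status`` follows a common rename-with-deprecation pattern. The sketch below isolates that reconciliation logic so it can be exercised on its own. The function name ``resolve_polling_attempts`` is hypothetical, and it raises ``ValueError`` where the diff raises a bare ``Exception``; both are assumptions made for illustration rather than part of the change.

```python
import warnings
from typing import Optional


def resolve_polling_attempts(
    max_tries: Optional[int] = None,
    max_polling_attempts: Optional[int] = None,
) -> Optional[int]:
    """Hypothetical helper: reconcile the deprecated and the new argument."""
    if max_tries:
        warnings.warn(
            "'max_tries' is deprecated; use 'max_polling_attempts' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Both arguments given: they must agree, otherwise the caller's intent is ambiguous.
        if max_polling_attempts and max_polling_attempts != max_tries:
            raise ValueError("max_polling_attempts must be the same value as max_tries")
        return max_tries
    # Only the new argument (or neither) was given.
    return max_polling_attempts
```

As in the diff, the truthiness checks mean a value of ``0`` is treated the same as ``None``, that is, as "no limit".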
diff --git a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py b/airflow/providers/amazon/aws/hooks/aws_dynamodb.py
deleted file mode 100644
index dedb80073e3e5..0000000000000
--- a/airflow/providers/amazon/aws/hooks/aws_dynamodb.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.dynamodb`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.dynamodb import AwsDynamoDBHook # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.dynamodb`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/hooks/base_aws.py b/airflow/providers/amazon/aws/hooks/base_aws.py
index bbf0bfff83e91..b902d102cb346 100644
--- a/airflow/providers/amazon/aws/hooks/base_aws.py
+++ b/airflow/providers/amazon/aws/hooks/base_aws.py
@@ -15,44 +15,46 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This module contains Base AWS Hook.
.. seealso::
For more information on how to use this hook, take a look at the guide:
- :ref:`howto/connection:AWSHook`
+ :ref:`howto/connection:aws`
"""
+from __future__ import annotations
-import configparser
import datetime
+import json
import logging
-import sys
import warnings
from functools import wraps
-from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar, Union
import boto3
import botocore
import botocore.session
import requests
import tenacity
+from botocore.client import ClientMeta
from botocore.config import Config
from botocore.credentials import ReadOnlyCredentials
-from slugify import slugify
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
from dateutil.tz import tzlocal
+from slugify import slugify
+from airflow.compat.functools import cached_property
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
-from airflow.models.connection import Connection
+from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.log.secrets_masker import mask_secret
+
+BaseAwsConnection = TypeVar("BaseAwsConnection", bound=Union[boto3.client, boto3.resource])
+
+
+if TYPE_CHECKING:
+ from airflow.models.connection import Connection # Avoid circular imports.
class BaseSessionFactory(LoggingMixin):
@@ -67,73 +69,85 @@ class BaseSessionFactory(LoggingMixin):
:ref:`howto/connection:aws:session-factory`
"""
- def __init__(self, conn: Connection, region_name: Optional[str], config: Config) -> None:
+ def __init__(
+ self,
+ conn: Connection | AwsConnectionWrapper | None,
+ region_name: str | None = None,
+ config: Config | None = None,
+ ) -> None:
super().__init__()
- self.conn = conn
- self.region_name = region_name
- self.config = config
- self.extra_config = self.conn.extra_dejson
+ self._conn = conn
+ self._region_name = region_name
+ self._config = config
- self.basic_session: Optional[boto3.session.Session] = None
- self.role_arn: Optional[str] = None
+ @cached_property
+ def conn(self) -> AwsConnectionWrapper:
+ """Cached AWS Connection Wrapper."""
+ return AwsConnectionWrapper(
+ conn=self._conn,
+ region_name=self._region_name,
+ botocore_config=self._config,
+ )
+
+ @cached_property
+ def basic_session(self) -> boto3.session.Session:
+ """Cached property with basic boto3.session.Session."""
+ return self._create_basic_session(session_kwargs=self.conn.session_kwargs)
+
+ @property
+ def extra_config(self) -> dict[str, Any]:
+ """AWS Connection extra_config."""
+ return self.conn.extra_config
+
+ @property
+ def region_name(self) -> str | None:
+ """AWS Region Name read-only property."""
+ return self.conn.region_name
+
+ @property
+ def config(self) -> Config | None:
+ """Configuration for botocore client read-only property."""
+ return self.conn.botocore_config
+
+ @property
+ def role_arn(self) -> str | None:
+ """Assume Role ARN from AWS Connection"""
+ return self.conn.role_arn
def create_session(self) -> boto3.session.Session:
- """Create AWS session."""
- session_kwargs = {}
- if "session_kwargs" in self.extra_config:
+ """Create boto3 Session from connection config."""
+ if not self.conn:
self.log.info(
- "Retrieving session_kwargs from Connection.extra_config['session_kwargs']: %s",
- self.extra_config["session_kwargs"],
+ "No connection ID provided. Fallback on boto3 credential strategy (region_name=%r). "
+ "See: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html",
+ self.region_name,
)
- session_kwargs = self.extra_config["session_kwargs"]
- self.basic_session = self._create_basic_session(session_kwargs=session_kwargs)
- self.role_arn = self._read_role_arn_from_extra_config()
- # If role_arn was specified then STS + assume_role
- if self.role_arn is None:
+ return boto3.session.Session(region_name=self.region_name)
+ elif not self.role_arn:
return self.basic_session
- return self._create_session_with_assume_role(session_kwargs=session_kwargs)
-
- def _get_region_name(self) -> Optional[str]:
- region_name = self.region_name
- if self.region_name is None and 'region_name' in self.extra_config:
- self.log.info("Retrieving region_name from Connection.extra_config['region_name']")
- region_name = self.extra_config["region_name"]
- return region_name
-
- def _create_basic_session(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
- aws_access_key_id, aws_secret_access_key = self._read_credentials_from_connection()
- aws_session_token = self.extra_config.get("aws_session_token")
- region_name = self._get_region_name()
- self.log.debug(
- "Creating session with aws_access_key_id=%s region_name=%s",
- aws_access_key_id,
- region_name,
- )
-
- return boto3.session.Session(
- aws_access_key_id=aws_access_key_id,
- aws_secret_access_key=aws_secret_access_key,
- region_name=region_name,
- aws_session_token=aws_session_token,
- **session_kwargs,
- )
-
- def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> boto3.session.Session:
- assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
- self.log.debug("assume_role_method=%s", assume_role_method)
- supported_methods = ['assume_role', 'assume_role_with_saml', 'assume_role_with_web_identity']
- if assume_role_method not in supported_methods:
- raise NotImplementedError(
- f'assume_role_method={assume_role_method} in Connection {self.conn.conn_id} Extra.'
- f'Currently {supported_methods} are supported.'
- '(Exclude this setting will default to "assume_role").'
- )
- if assume_role_method == 'assume_role_with_web_identity':
+ # Values stored in ``AwsConnectionWrapper.session_kwargs`` are intended to be used only
+ # to create the initial boto3 session.
+ # If the user wants to use the 'assume_role' mechanism then only the 'region_name' needs to be
+ # provided, otherwise other parameters might conflict with the base botocore session.
+ # Unfortunately it is not a part of public boto3 API, see source of boto3.session.Session:
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/_modules/boto3/session.html#Session
+ # If we provide 'aws_access_key_id' or 'aws_secret_access_key' or 'aws_session_token'
+ # as part of session kwargs it will use them instead of assumed credentials.
+ assume_session_kwargs = {}
+ if self.conn.region_name:
+ assume_session_kwargs["region_name"] = self.conn.region_name
+ return self._create_session_with_assume_role(session_kwargs=assume_session_kwargs)
+
+ def _create_basic_session(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
+ return boto3.session.Session(**session_kwargs)
+
+ def _create_session_with_assume_role(self, session_kwargs: dict[str, Any]) -> boto3.session.Session:
+ if self.conn.assume_role_method == "assume_role_with_web_identity":
# Deferred credentials have no initial credentials
credential_fetcher = self._get_web_identity_credential_fetcher()
credentials = botocore.credentials.DeferredRefreshableCredentials(
- method='assume-role-with-web-identity',
+ method="assume-role-with-web-identity",
refresh_using=credential_fetcher.fetch_credentials,
time_fetcher=lambda: datetime.datetime.now(tz=tzlocal()),
)
@@ -144,41 +158,34 @@ def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> bo
refresh_using=self._refresh_credentials,
method="sts-assume-role",
)
+
session = botocore.session.get_session()
session._credentials = credentials
-
- if self.basic_session is None:
- raise RuntimeError("The basic session should be created here!")
-
region_name = self.basic_session.region_name
session.set_config_variable("region", region_name)
return boto3.session.Session(botocore_session=session, **session_kwargs)
- def _refresh_credentials(self) -> Dict[str, Any]:
- self.log.debug('Refreshing credentials')
- assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
- sts_session = self.basic_session
-
- if sts_session is None:
- raise RuntimeError(
- "Session should be initialized when refresh credentials with assume_role is used!"
- )
+ def _refresh_credentials(self) -> dict[str, Any]:
+ self.log.debug("Refreshing credentials")
+ assume_role_method = self.conn.assume_role_method
+ if assume_role_method not in ("assume_role", "assume_role_with_saml"):
+ raise NotImplementedError(f"assume_role_method={assume_role_method} not expected")
- sts_client = sts_session.client("sts", config=self.config)
+ sts_client = self.basic_session.client("sts", config=self.config)
- if assume_role_method == 'assume_role':
+ if assume_role_method == "assume_role":
sts_response = self._assume_role(sts_client=sts_client)
- elif assume_role_method == 'assume_role_with_saml':
- sts_response = self._assume_role_with_saml(sts_client=sts_client)
else:
- raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
- sts_response_http_status = sts_response['ResponseMetadata']['HTTPStatusCode']
- if not sts_response_http_status == 200:
- raise RuntimeError(f'sts_response_http_status={sts_response_http_status}')
- credentials = sts_response['Credentials']
- expiry_time = credentials.get('Expiration').isoformat()
- self.log.debug('New credentials expiry_time: %s', expiry_time)
+ sts_response = self._assume_role_with_saml(sts_client=sts_client)
+
+ sts_response_http_status = sts_response["ResponseMetadata"]["HTTPStatusCode"]
+ if sts_response_http_status != 200:
+ raise RuntimeError(f"sts_response_http_status={sts_response_http_status}")
+
+ credentials = sts_response["Credentials"]
+ expiry_time = credentials.get("Expiration").isoformat()
+ self.log.debug("New credentials expiry_time: %s", expiry_time)
credentials = {
"access_key": credentials.get("AccessKeyId"),
"secret_key": credentials.get("SecretAccessKey"),
@@ -187,77 +194,37 @@ def _refresh_credentials(self) -> Dict[str, Any]:
}
return credentials
- def _read_role_arn_from_extra_config(self) -> Optional[str]:
- aws_account_id = self.extra_config.get("aws_account_id")
- aws_iam_role = self.extra_config.get("aws_iam_role")
- role_arn = self.extra_config.get("role_arn")
- if role_arn is None and aws_account_id is not None and aws_iam_role is not None:
- self.log.info("Constructing role_arn from aws_account_id and aws_iam_role")
- role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
- self.log.debug("role_arn is %s", role_arn)
- return role_arn
-
- def _read_credentials_from_connection(self) -> Tuple[Optional[str], Optional[str]]:
- aws_access_key_id = None
- aws_secret_access_key = None
- if self.conn.login:
- aws_access_key_id = self.conn.login
- aws_secret_access_key = self.conn.password
- self.log.info("Credentials retrieved from login")
- elif "aws_access_key_id" in self.extra_config and "aws_secret_access_key" in self.extra_config:
- aws_access_key_id = self.extra_config["aws_access_key_id"]
- aws_secret_access_key = self.extra_config["aws_secret_access_key"]
- self.log.info("Credentials retrieved from extra_config")
- elif "s3_config_file" in self.extra_config:
- aws_access_key_id, aws_secret_access_key = _parse_s3_config(
- self.extra_config["s3_config_file"],
- self.extra_config.get("s3_config_format"),
- self.extra_config.get("profile"),
- )
- self.log.info("Credentials retrieved from extra_config['s3_config_file']")
- return aws_access_key_id, aws_secret_access_key
-
- def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
- return slugify(role_session_name, regex_pattern=r'[^\w+=,.@-]+')
-
- def _assume_role(self, sts_client: boto3.client) -> Dict:
- assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
- if "external_id" in self.extra_config: # Backwards compatibility
- assume_role_kwargs["ExternalId"] = self.extra_config.get("external_id")
- role_session_name = self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}")
- self.log.debug(
- "Doing sts_client.assume_role to role_arn=%s (role_session_name=%s)",
- self.role_arn,
- role_session_name,
- )
- return sts_client.assume_role(
- RoleArn=self.role_arn, RoleSessionName=role_session_name, **assume_role_kwargs
- )
+ def _assume_role(self, sts_client: boto3.client) -> dict:
+ kw = {
+ "RoleSessionName": self._strip_invalid_session_name_characters(f"Airflow_{self.conn.conn_id}"),
+ **self.conn.assume_role_kwargs,
+ "RoleArn": self.role_arn,
+ }
+ return sts_client.assume_role(**kw)
- def _assume_role_with_saml(self, sts_client: boto3.client) -> Dict[str, Any]:
- saml_config = self.extra_config['assume_role_with_saml']
- principal_arn = saml_config['principal_arn']
+ def _assume_role_with_saml(self, sts_client: boto3.client) -> dict[str, Any]:
+ saml_config = self.extra_config["assume_role_with_saml"]
+ principal_arn = saml_config["principal_arn"]
- idp_auth_method = saml_config['idp_auth_method']
- if idp_auth_method == 'http_spegno_auth':
+ idp_auth_method = saml_config["idp_auth_method"]
+ if idp_auth_method == "http_spegno_auth":
saml_assertion = self._fetch_saml_assertion_using_http_spegno_auth(saml_config)
else:
raise NotImplementedError(
- f'idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra.'
+ f"idp_auth_method={idp_auth_method} in Connection {self.conn.conn_id} Extra."
'Currently only "http_spegno_auth" is supported, and must be specified.'
)
self.log.debug("Doing sts_client.assume_role_with_saml to role_arn=%s", self.role_arn)
- assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
return sts_client.assume_role_with_saml(
RoleArn=self.role_arn,
PrincipalArn=principal_arn,
SAMLAssertion=saml_assertion,
- **assume_role_kwargs,
+ **self.conn.assume_role_kwargs,
)
def _get_idp_response(
- self, saml_config: Dict[str, Any], auth: requests.auth.AuthBase
+ self, saml_config: dict[str, Any], auth: requests.auth.AuthBase
) -> requests.models.Response:
idp_url = saml_config["idp_url"]
self.log.debug("idp_url= %s", idp_url)
@@ -285,7 +252,7 @@ def _get_idp_response(
return idp_response
- def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
+ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: dict[str, Any]) -> str:
# requests_gssapi will need paramiko > 2.6 since you'll need
# 'gssapi' not 'python-gssapi' from PyPi.
# https://github.com/paramiko/paramiko/pull/1311
@@ -293,32 +260,32 @@ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, An
from lxml import etree
auth = requests_gssapi.HTTPSPNEGOAuth()
- if 'mutual_authentication' in saml_config:
- mutual_auth = saml_config['mutual_authentication']
- if mutual_auth == 'REQUIRED':
+ if "mutual_authentication" in saml_config:
+ mutual_auth = saml_config["mutual_authentication"]
+ if mutual_auth == "REQUIRED":
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.REQUIRED)
- elif mutual_auth == 'OPTIONAL':
+ elif mutual_auth == "OPTIONAL":
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.OPTIONAL)
- elif mutual_auth == 'DISABLED':
+ elif mutual_auth == "DISABLED":
auth = requests_gssapi.HTTPSPNEGOAuth(requests_gssapi.DISABLED)
else:
raise NotImplementedError(
- f'mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra.'
+ f"mutual_authentication={mutual_auth} in Connection {self.conn.conn_id} Extra."
'Currently "REQUIRED", "OPTIONAL" and "DISABLED" are supported.'
- '(Exclude this setting will default to HTTPSPNEGOAuth() ).'
+ "(Exclude this setting will default to HTTPSPNEGOAuth() )."
)
# Query the IDP
idp_response = self._get_idp_response(saml_config, auth=auth)
# Assist with debugging. Note: contains sensitive info!
- xpath = saml_config['saml_response_xpath']
- log_idp_response = 'log_idp_response' in saml_config and saml_config['log_idp_response']
+ xpath = saml_config["saml_response_xpath"]
+ log_idp_response = "log_idp_response" in saml_config and saml_config["log_idp_response"]
if log_idp_response:
self.log.warning(
- 'The IDP response contains sensitive information, but log_idp_response is ON (%s).',
+ "The IDP response contains sensitive information, but log_idp_response is ON (%s).",
log_idp_response,
)
- self.log.debug('idp_response.content= %s', idp_response.content)
- self.log.debug('xpath= %s', xpath)
+ self.log.debug("idp_response.content= %s", idp_response.content)
+ self.log.debug("xpath= %s", xpath)
# Extract SAML Assertion from the returned HTML / XML
xml = etree.fromstring(idp_response.content)
saml_assertion = xml.xpath(xpath)
@@ -326,29 +293,26 @@ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, An
if len(saml_assertion) == 1:
saml_assertion = saml_assertion[0]
if not saml_assertion:
- raise ValueError('Invalid SAML Assertion')
+ raise ValueError("Invalid SAML Assertion")
return saml_assertion
def _get_web_identity_credential_fetcher(
self,
) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher:
- if self.basic_session is None:
- raise Exception("Session should be set where identity is fetched!")
base_session = self.basic_session._session or botocore.session.get_session()
client_creator = base_session.create_client
- federation = self.extra_config.get('assume_role_with_web_identity_federation')
- if federation == 'google':
+ federation = self.extra_config.get("assume_role_with_web_identity_federation")
+ if federation == "google":
web_identity_token_loader = self._get_google_identity_token_loader()
else:
raise AirflowException(
f'Unsupported federation: {federation}. Currently "google" only are supported.'
)
- assume_role_kwargs = self.extra_config.get("assume_role_kwargs", {})
return botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher(
client_creator=client_creator,
web_identity_token_loader=web_identity_token_loader,
role_arn=self.role_arn,
- extra_args=assume_role_kwargs,
+ extra_args=self.conn.assume_role_kwargs,
)
def _get_google_identity_token_loader(self):
@@ -358,7 +322,7 @@ def _get_google_identity_token_loader(self):
get_default_id_token_credentials,
)
- audience = self.extra_config.get('assume_role_with_web_identity_federation_audience')
+ audience = self.extra_config.get("assume_role_with_web_identity_federation_audience")
google_id_token_credentials = get_default_id_token_credentials(target_audience=audience)
@@ -370,8 +334,39 @@ def web_identity_token_loader():
return web_identity_token_loader
+ def _strip_invalid_session_name_characters(self, role_session_name: str) -> str:
+ return slugify(role_session_name, regex_pattern=r"[^\w+=,.@-]+")
+
+ def _get_region_name(self) -> str | None:
+ warnings.warn(
+ "`BaseSessionFactory._get_region_name` method is deprecated and will be removed "
+ "in a future release. Please use `BaseSessionFactory.region_name` property instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.region_name
+
+ def _read_role_arn_from_extra_config(self) -> str | None:
+ warnings.warn(
+ "`BaseSessionFactory._read_role_arn_from_extra_config` method is deprecated and will be "
+ "removed in a future release. Please use `BaseSessionFactory.role_arn` property instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.role_arn
+
+ def _read_credentials_from_connection(self) -> tuple[str | None, str | None]:
+ warnings.warn(
+ "`BaseSessionFactory._read_credentials_from_connection` method is deprecated and will be "
+ "removed in a future release. Please use `BaseSessionFactory.conn.aws_access_key_id` and "
+ "`BaseSessionFactory.conn.aws_secret_access_key` properties instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return self.conn.aws_access_key_id, self.conn.aws_secret_access_key
-class AwsBaseHook(BaseHook):
+
+class AwsGenericHook(BaseHook, Generic[BaseAwsConnection]):
"""
Interact with AWS.
This class is a thin wrapper around the boto3 python library.
@@ -381,147 +376,152 @@ class AwsBaseHook(BaseHook):
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
- :param verify: Whether or not to verify SSL certificates.
+ :param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param client_type: boto3.client client_type. Eg 's3', 'emr' etc
:param resource_type: boto3.resource resource_type. Eg 'dynamodb' etc
- :param config: Configuration for botocore client.
- (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html)
+ :param config: Configuration for botocore client. See:
+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""
- conn_name_attr = 'aws_conn_id'
- default_conn_name = 'aws_default'
- conn_type = 'aws'
- hook_name = 'Amazon Web Services'
+ conn_name_attr = "aws_conn_id"
+ default_conn_name = "aws_default"
+ conn_type = "aws"
+ hook_name = "Amazon Web Services"
def __init__(
self,
- aws_conn_id: Optional[str] = default_conn_name,
- verify: Union[bool, str, None] = None,
- region_name: Optional[str] = None,
- client_type: Optional[str] = None,
- resource_type: Optional[str] = None,
- config: Optional[Config] = None,
+ aws_conn_id: str | None = default_conn_name,
+ verify: bool | str | None = None,
+ region_name: str | None = None,
+ client_type: str | None = None,
+ resource_type: str | None = None,
+ config: Config | None = None,
) -> None:
super().__init__()
self.aws_conn_id = aws_conn_id
- self.verify = verify
self.client_type = client_type
self.resource_type = resource_type
- self.region_name = region_name
- self.config = config
-
- if not (self.client_type or self.resource_type):
- raise AirflowException('Either client_type or resource_type must be provided.')
-
- def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:
- if not self.aws_conn_id:
- session = boto3.session.Session(region_name=region_name)
- return session, None
+ self._region_name = region_name
+ self._config = config
+ self._verify = verify
- self.log.debug("Airflow Connection: aws_conn_id=%s", self.aws_conn_id)
-
- try:
- # Fetch the Airflow connection object
- connection_object = self.get_connection(self.aws_conn_id)
- extra_config = connection_object.extra_dejson
- endpoint_url = extra_config.get("host")
-
- # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html#botocore.config.Config
- if "config_kwargs" in extra_config:
- self.log.debug(
- "Retrieving config_kwargs from Connection.extra_config['config_kwargs']: %s",
- extra_config["config_kwargs"],
+ @cached_property
+ def conn_config(self) -> AwsConnectionWrapper:
+ """Get the Airflow Connection object and wrap it in helper (cached)."""
+ connection = None
+ if self.aws_conn_id:
+ try:
+ connection = self.get_connection(self.aws_conn_id)
+ except AirflowNotFoundException:
+ warnings.warn(
+ f"Unable to find AWS Connection ID '{self.aws_conn_id}', switching to empty. "
+ "This behaviour is deprecated and will be removed in a future release. "
+ "Please provide an existing AWS Connection ID or, for the boto3 credential strategy, "
+ "explicitly set the AWS Connection ID to None.",
+ DeprecationWarning,
+ stacklevel=2,
)
- self.config = Config(**extra_config["config_kwargs"])
- session = SessionFactory(
- conn=connection_object, region_name=region_name, config=self.config
- ).create_session()
+ return AwsConnectionWrapper(
+ conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
+ )
- return session, endpoint_url
+ @property
+ def region_name(self) -> str | None:
+ """AWS Region Name read-only property."""
+ return self.conn_config.region_name
- except AirflowException:
- self.log.warning("Unable to use Airflow Connection for credentials.")
- self.log.debug("Fallback on boto3 credential strategy")
- # http://boto3.readthedocs.io/en/latest/guide/configuration.html
+ @property
+ def config(self) -> Config | None:
+ """Configuration for botocore client read-only property."""
+ return self.conn_config.botocore_config
- self.log.debug(
- "Creating session using boto3 credential strategy region_name=%s",
- region_name,
- )
- session = boto3.session.Session(region_name=region_name)
- return session, None
+ @property
+ def verify(self) -> bool | str | None:
+ """Whether to verify SSL certificates for the boto3 client/resource (read-only property)."""
+ return self.conn_config.verify
+
+ def get_session(self, region_name: str | None = None) -> boto3.session.Session:
+ """Get the underlying boto3.session.Session(region_name=region_name)."""
+ return SessionFactory(
+ conn=self.conn_config, region_name=region_name, config=self.config
+ ).create_session()
def get_client_type(
self,
- client_type: Optional[str] = None,
- region_name: Optional[str] = None,
- config: Optional[Config] = None,
+ region_name: str | None = None,
+ config: Config | None = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
- session, endpoint_url = self._get_credentials(region_name=region_name)
-
- if client_type:
- warnings.warn(
- "client_type is deprecated. Set client_type from class attribute.",
- DeprecationWarning,
- stacklevel=2,
- )
- else:
- client_type = self.client_type
+ client_type = self.client_type
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config
- return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
+ session = self.get_session(region_name=region_name)
+ return session.client(
+ client_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+ )
def get_resource_type(
self,
- resource_type: Optional[str] = None,
- region_name: Optional[str] = None,
- config: Optional[Config] = None,
+ region_name: str | None = None,
+ config: Config | None = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
- session, endpoint_url = self._get_credentials(region_name=region_name)
-
- if resource_type:
- warnings.warn(
- "resource_type is deprecated. Set resource_type from class attribute.",
- DeprecationWarning,
- stacklevel=2,
- )
- else:
- resource_type = self.resource_type
+ resource_type = self.resource_type
# No AWS Operators use the config argument to this method.
# Keep backward compatibility with other users who might use it
if config is None:
config = self.config
- return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify)
+ session = self.get_session(region_name=region_name)
+ return session.resource(
+ resource_type, endpoint_url=self.conn_config.endpoint_url, config=config, verify=self.verify
+ )
@cached_property
- def conn(self) -> Union[boto3.client, boto3.resource]:
+ def conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
:return: boto3.client or boto3.resource
- :rtype: Union[boto3.client, boto3.resource]
"""
- if self.client_type:
+ if not ((not self.client_type) ^ (not self.resource_type)):
+ raise ValueError(
+ f"Either client_type={self.client_type!r} or "
+ f"resource_type={self.resource_type!r} must be provided, not both."
+ )
+ elif self.client_type:
return self.get_client_type(region_name=self.region_name)
- elif self.resource_type:
- return self.get_resource_type(region_name=self.region_name)
else:
- # Rare possibility - subclasses have not specified a client_type or resource_type
- raise NotImplementedError('Could not get boto3 connection!')
+ return self.get_resource_type(region_name=self.region_name)
- def get_conn(self) -> Union[boto3.client, boto3.resource]:
+ @cached_property
+ def conn_client_meta(self) -> ClientMeta:
+ """Get botocore client metadata from Hook connection (cached)."""
+ conn = self.conn
+ if isinstance(conn, botocore.client.BaseClient):
+ return conn.meta
+ return conn.meta.client.meta
+
+ @property
+ def conn_region_name(self) -> str:
+ """Get actual AWS Region Name from Hook connection (cached)."""
+ return self.conn_client_meta.region_name
+
+ @property
+ def conn_partition(self) -> str:
+ """Get associated AWS Region Partition from Hook connection (cached)."""
+ return self.conn_client_meta.partition
+
+ def get_conn(self) -> BaseAwsConnection:
"""
Get the underlying boto3 client/resource (cached)
@@ -529,29 +529,27 @@ def get_conn(self) -> Union[boto3.client, boto3.resource]:
with subclasses that rely on a super().get_conn() method.
:return: boto3.client or boto3.resource
- :rtype: Union[boto3.client, boto3.resource]
"""
# Compat shim
return self.conn
- def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
- """Get the underlying boto3.session."""
- session, _ = self._get_credentials(region_name=region_name)
- return session
-
- def get_credentials(self, region_name: Optional[str] = None) -> ReadOnlyCredentials:
+ def get_credentials(self, region_name: str | None = None) -> ReadOnlyCredentials:
"""
Get the underlying `botocore.Credentials` object.
This contains the following authentication attributes: access_key, secret_key and token.
+ Calling this method also masks secret_key and token in task logs.
"""
- session, _ = self._get_credentials(region_name=region_name)
# Credentials are refreshable, so accessing your access key and
# secret key separately can lead to a race condition.
# See https://stackoverflow.com/a/36291428/8283373
- return session.get_credentials().get_frozen_credentials()
+ creds = self.get_session(region_name=region_name).get_credentials().get_frozen_credentials()
+ mask_secret(creds.secret_key)
+ if creds.token:
+ mask_secret(creds.token)
+ return creds
- def expand_role(self, role: str, region_name: Optional[str] = None) -> str:
+ def expand_role(self, role: str, region_name: str | None = None) -> str:
"""
If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role.
If IAM role is already an IAM role ARN, no change is made.
@@ -563,8 +561,10 @@ def expand_role(self, role: str, region_name: Optional[str] = None) -> str:
if "/" in role:
return role
else:
- session, endpoint_url = self._get_credentials(region_name=region_name)
- _client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify)
+ session = self.get_session(region_name=region_name)
+ _client = session.client(
+ "iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
+ )
return _client.get_role(RoleName=role)["Role"]["Arn"]
@staticmethod
@@ -577,21 +577,21 @@ def retry(should_retry: Callable[[Exception], bool]):
def retry_decorator(fun: Callable):
@wraps(fun)
def decorator_f(self, *args, **kwargs):
- retry_args = getattr(self, 'retry_args', None)
+ retry_args = getattr(self, "retry_args", None)
if retry_args is None:
return fun(self, *args, **kwargs)
- multiplier = retry_args.get('multiplier', 1)
- min_limit = retry_args.get('min', 1)
- max_limit = retry_args.get('max', 1)
- stop_after_delay = retry_args.get('stop_after_delay', 10)
+ multiplier = retry_args.get("multiplier", 1)
+ min_limit = retry_args.get("min", 1)
+ max_limit = retry_args.get("max", 1)
+ stop_after_delay = retry_args.get("stop_after_delay", 10)
tenacity_before_logger = tenacity.before_log(self.log, logging.INFO) if self.log else None
tenacity_after_logger = tenacity.after_log(self.log, logging.INFO) if self.log else None
default_kwargs = {
- 'wait': tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit),
- 'retry': tenacity.retry_if_exception(should_retry),
- 'stop': tenacity.stop_after_delay(stop_after_delay),
- 'before': tenacity_before_logger,
- 'after': tenacity_after_logger,
+ "wait": tenacity.wait_exponential(multiplier=multiplier, max=max_limit, min=min_limit),
+ "retry": tenacity.retry_if_exception(should_retry),
+ "stop": tenacity.stop_after_delay(stop_after_delay),
+ "before": tenacity_before_logger,
+ "after": tenacity_after_logger,
}
return tenacity.retry(**default_kwargs)(fun)(self, *args, **kwargs)
@@ -599,59 +599,81 @@ def decorator_f(self, *args, **kwargs):
return retry_decorator
+ def _get_credentials(self, region_name: str | None) -> tuple[boto3.session.Session, str | None]:
+ warnings.warn(
+ "`AwsGenericHook._get_credentials` method is deprecated and will be removed in a future "
+ "release. Please use `AwsGenericHook.get_session` method and "
+ "`AwsGenericHook.conn_config.endpoint_url` property instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ return self.get_session(region_name=region_name), self.conn_config.endpoint_url
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Returns custom UI field behaviour for AWS Connection."""
+ return {
+ "hidden_fields": ["host", "schema", "port"],
+ "relabeling": {
+ "login": "AWS Access Key ID",
+ "password": "AWS Secret Access Key",
+ },
+ "placeholders": {
+ "login": "AKIAIOSFODNN7EXAMPLE",
+ "password": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
+ "extra": json.dumps(
+ {
+ "region_name": "us-east-1",
+ "session_kwargs": {"profile_name": "default"},
+ "config_kwargs": {"retries": {"mode": "standard", "max_attempts": 10}},
+ "role_arn": "arn:aws:iam::123456789098:role/role-name",
+ "assume_role_method": "assume_role",
+ "assume_role_kwargs": {"RoleSessionName": "airflow"},
+ "aws_session_token": "AQoDYXdzEJr...EXAMPLETOKEN",
+ "endpoint_url": "http://localhost:4566",
+ },
+ indent=2,
+ ),
+ },
+ }
+
+ def test_connection(self):
+ """
+ Tests the AWS connection by calling the AWS STS (Security Token Service) GetCallerIdentity API.
-def _parse_s3_config(
- config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
-) -> Tuple[Optional[str], Optional[str]]:
+ .. seealso::
+ https://docs.aws.amazon.com/STS/latest/APIReference/API_GetCallerIdentity.html
+ """
+ try:
+ session = self.get_session()
+ conn_info = session.client("sts").get_caller_identity()
+ metadata = conn_info.pop("ResponseMetadata", {})
+ if metadata.get("HTTPStatusCode") != 200:
+ try:
+ return False, json.dumps(metadata)
+ except TypeError:
+ return False, str(metadata)
+ conn_info["credentials_method"] = session.get_credentials().method
+ conn_info["region_name"] = session.region_name
+ return True, ", ".join(f"{k}={v!r}" for k, v in conn_info.items())
+
+ except Exception as e:
+ return False, str(f"{type(e).__name__!r} error occurred while testing connection: {e}")
+
+
+class AwsBaseHook(AwsGenericHook[Union[boto3.client, boto3.resource]]):
"""
- Parses a config file for s3 credentials. Can currently
- parse boto, s3cmd.conf and AWS SDK config formats
+ Interact with AWS.
+ This class is a thin wrapper around the boto3 python library
+ with basic conn annotation.
- :param config_file_name: path to the config file
- :param config_format: config type. One of "boto", "s3cmd" or "aws".
- Defaults to "boto"
- :param profile: profile name in AWS type config file
+ .. seealso::
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
"""
- config = configparser.ConfigParser()
- if config.read(config_file_name): # pragma: no cover
- sections = config.sections()
- else:
- raise AirflowException(f"Couldn't read {config_file_name}")
- # Setting option names depending on file format
- if config_format is None:
- config_format = "boto"
- conf_format = config_format.lower()
- if conf_format == "boto": # pragma: no cover
- if profile is not None and "profile " + profile in sections:
- cred_section = "profile " + profile
- else:
- cred_section = "Credentials"
- elif conf_format == "aws" and profile is not None:
- cred_section = profile
- else:
- cred_section = "default"
- # Option names
- if conf_format in ("boto", "aws"): # pragma: no cover
- key_id_option = "aws_access_key_id"
- secret_key_option = "aws_secret_access_key"
- # security_token_option = 'aws_security_token'
- else:
- key_id_option = "access_key"
- secret_key_option = "secret_key"
- # Actual Parsing
- if cred_section not in sections:
- raise AirflowException("This config file format is not recognized")
- else:
- try:
- access_key = config.get(cred_section, key_id_option)
- secret_key = config.get(cred_section, secret_key_option)
- except Exception:
- logging.warning("Option Error in parsing s3 config file")
- raise
- return access_key, secret_key
-def resolve_session_factory() -> Type[BaseSessionFactory]:
+def resolve_session_factory() -> type[BaseSessionFactory]:
"""Resolves custom SessionFactory class"""
clazz = conf.getimport("aws", "session_factory", fallback=None)
if not clazz:
@@ -665,3 +687,14 @@ def resolve_session_factory() -> Type[BaseSessionFactory]:
SessionFactory = resolve_session_factory()
+
+
+def _parse_s3_config(config_file_name: str, config_format: str | None = "boto", profile: str | None = None):
+ """For compatibility with airflow.contrib.hooks.aws_hook"""
+ from airflow.providers.amazon.aws.utils.connection_wrapper import _parse_s3_config
+
+ return _parse_s3_config(
+ config_file_name=config_file_name,
+ config_format=config_format,
+ profile=profile,
+ )
diff --git a/airflow/providers/amazon/aws/hooks/batch_client.py b/airflow/providers/amazon/aws/hooks/batch_client.py
index 3b10012b3943f..e9080189e76b2 100644
--- a/airflow/providers/amazon/aws/hooks/batch_client.py
+++ b/airflow/providers/amazon/aws/hooks/batch_client.py
@@ -20,14 +20,14 @@
.. seealso::
- - http://boto3.readthedocs.io/en/latest/guide/configuration.html
- - http://boto3.readthedocs.io/en/latest/reference/services/batch.html
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
-import warnings
+from __future__ import annotations
+
from random import uniform
from time import sleep
-from typing import Dict, List, Optional, Union
import botocore.client
import botocore.exceptions
@@ -48,17 +48,16 @@ class BatchProtocol(Protocol):
.. seealso::
- https://mypy.readthedocs.io/en/latest/protocols.html
- - http://boto3.readthedocs.io/en/latest/reference/services/batch.html
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
"""
- def describe_jobs(self, jobs: List[str]) -> Dict:
+ def describe_jobs(self, jobs: list[str]) -> dict:
"""
Get job descriptions from AWS Batch
:param jobs: a list of JobId to describe
:return: an API response to describe jobs
- :rtype: Dict
"""
...
@@ -71,7 +70,6 @@ def get_waiter(self, waiterName: str) -> botocore.waiter.Waiter:
model file (typically this is CamelCasing).
:return: a waiter object for the named AWS Batch service
- :rtype: botocore.waiter.Waiter
.. note::
AWS Batch might not have any waiters (until botocore PR-1307 is released).
@@ -94,11 +92,11 @@ def submit_job(
jobName: str,
jobQueue: str,
jobDefinition: str,
- arrayProperties: Dict,
- parameters: Dict,
- containerOverrides: Dict,
- tags: Dict,
- ) -> Dict:
+ arrayProperties: dict,
+ parameters: dict,
+ containerOverrides: dict,
+ tags: dict,
+ ) -> dict:
"""
Submit a Batch job
@@ -117,11 +115,10 @@ def submit_job(
:param tags: the same parameter that boto3 will receive
:return: an API response
- :rtype: Dict
"""
...
- def terminate_job(self, jobId: str, reason: str) -> Dict:
+ def terminate_job(self, jobId: str, reason: str) -> dict:
"""
Terminate a Batch job
@@ -130,7 +127,6 @@ def terminate_job(self, jobId: str, reason: str) -> Dict:
:param reason: a reason to terminate job ID
:return: an API response
- :rtype: Dict
"""
...
@@ -181,36 +177,41 @@ class BatchClientHook(AwsBaseHook):
DEFAULT_DELAY_MIN = 1
DEFAULT_DELAY_MAX = 10
- FAILURE_STATE = 'FAILED'
- SUCCESS_STATE = 'SUCCEEDED'
- RUNNING_STATE = 'RUNNING'
+ FAILURE_STATE = "FAILED"
+ SUCCESS_STATE = "SUCCEEDED"
+ RUNNING_STATE = "RUNNING"
INTERMEDIATE_STATES = (
- 'SUBMITTED',
- 'PENDING',
- 'RUNNABLE',
- 'STARTING',
+ "SUBMITTED",
+ "PENDING",
+ "RUNNABLE",
+ "STARTING",
RUNNING_STATE,
)
+ COMPUTE_ENVIRONMENT_TERMINAL_STATUS = ("VALID", "DELETED")
+ COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS = ("CREATING", "UPDATING", "DELETING")
+
+ JOB_QUEUE_TERMINAL_STATUS = ("VALID", "DELETED")
+ JOB_QUEUE_INTERMEDIATE_STATUS = ("CREATING", "UPDATING", "DELETING")
+
def __init__(
- self, *args, max_retries: Optional[int] = None, status_retries: Optional[int] = None, **kwargs
+ self, *args, max_retries: int | None = None, status_retries: int | None = None, **kwargs
) -> None:
# https://github.com/python/mypy/issues/6799 hence type: ignore
- super().__init__(client_type='batch', *args, **kwargs) # type: ignore
+ super().__init__(client_type="batch", *args, **kwargs) # type: ignore
self.max_retries = max_retries or self.MAX_RETRIES
self.status_retries = status_retries or self.STATUS_RETRIES
@property
- def client(self) -> Union[BatchProtocol, botocore.client.BaseClient]:
+ def client(self) -> BatchProtocol | botocore.client.BaseClient:
"""
An AWS API client for Batch services.
:return: a boto3 'batch' client for the ``.region_name``
- :rtype: Union[BatchProtocol, botocore.client.BaseClient]
"""
return self.conn
- def terminate_job(self, job_id: str, reason: str) -> Dict:
+ def terminate_job(self, job_id: str, reason: str) -> dict:
"""
Terminate a Batch job
@@ -219,7 +220,6 @@ def terminate_job(self, job_id: str, reason: str) -> Dict:
:param reason: a reason to terminate job ID
:return: an API response
- :rtype: Dict
"""
response = self.get_conn().terminate_job(jobId=job_id, reason=reason)
self.log.info(response)
@@ -232,7 +232,6 @@ def check_job_success(self, job_id: str) -> bool:
:param job_id: a Batch job ID
- :rtype: bool
:raises: AirflowException
"""
@@ -251,7 +250,7 @@ def check_job_success(self, job_id: str) -> bool:
raise AirflowException(f"AWS Batch job ({job_id}) has unknown status: {job}")
- def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None:
+ def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
"""
Wait for Batch job to complete
@@ -266,7 +265,7 @@ def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> No
self.poll_for_job_complete(job_id, delay)
self.log.info("AWS Batch job (%s) has completed", job_id)
- def poll_for_job_running(self, job_id: str, delay: Union[int, float, None] = None) -> None:
+ def poll_for_job_running(self, job_id: str, delay: int | float | None = None) -> None:
"""
Poll for job running. The status that indicates a job is running or
already complete are: 'RUNNING'|'SUCCEEDED'|'FAILED'.
@@ -288,7 +287,7 @@ def poll_for_job_running(self, job_id: str, delay: Union[int, float, None] = Non
running_status = [self.RUNNING_STATE, self.SUCCESS_STATE, self.FAILURE_STATE]
self.poll_job_status(job_id, running_status)
- def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = None) -> None:
+ def poll_for_job_complete(self, job_id: str, delay: int | float | None = None) -> None:
"""
Poll for job completion. The status that indicates job completion
are: 'SUCCEEDED'|'FAILED'.
@@ -306,7 +305,7 @@ def poll_for_job_complete(self, job_id: str, delay: Union[int, float, None] = No
complete_status = [self.SUCCESS_STATE, self.FAILURE_STATE]
self.poll_job_status(job_id, complete_status)
- def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
+ def poll_job_status(self, job_id: str, match_status: list[str]) -> bool:
"""
Poll for job status using an exponential back-off strategy (with max_retries).
@@ -315,7 +314,6 @@ def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
:param match_status: a list of job status to match; the Batch job status are:
'SUBMITTED'|'PENDING'|'RUNNABLE'|'STARTING'|'RUNNING'|'SUCCEEDED'|'FAILED'
- :rtype: bool
:raises: AirflowException
"""
@@ -348,14 +346,13 @@ def poll_job_status(self, job_id: str, match_status: List[str]) -> bool:
)
self.delay(pause)
- def get_job_description(self, job_id: str) -> Dict:
+ def get_job_description(self, job_id: str) -> dict:
"""
Get job description (using status_retries).
:param job_id: a Batch job ID
:return: an API response for describe jobs
- :rtype: Dict
:raises: AirflowException
"""
@@ -390,7 +387,7 @@ def get_job_description(self, job_id: str) -> Dict:
self.delay(pause)
@staticmethod
- def parse_job_description(job_id: str, response: Dict) -> Dict:
+ def parse_job_description(job_id: str, response: dict) -> dict:
"""
Parse job description to extract description for job_id
@@ -399,7 +396,6 @@ def parse_job_description(job_id: str, response: Dict) -> Dict:
:param response: an API response for describe jobs
:return: an API response to describe job_id
- :rtype: Dict
:raises: AirflowException
"""
@@ -410,10 +406,47 @@ def parse_job_description(job_id: str, response: Dict) -> Dict:
return matching_jobs[0]
+ def get_job_awslogs_info(self, job_id: str) -> dict[str, str] | None:
+ """
+ Parse job description to extract AWS CloudWatch information.
+
+ :param job_id: AWS Batch Job ID
+ """
+ job_container_desc = self.get_job_description(job_id=job_id).get("container", {})
+ log_configuration = job_container_desc.get("logConfiguration", {})
+
+ # If the user selects a "logDriver" other than "awslogs",
+ # then CloudWatch logging should be disabled.
+ # If the user does not specify anything, "awslogs" is expected to be used
+ # with default settings:
+ # awslogs-group = /aws/batch/job
+ # awslogs-region = `same as AWS Batch Job region`
+ log_driver = log_configuration.get("logDriver", "awslogs")
+ if log_driver != "awslogs":
+ self.log.warning(
+ "AWS Batch job (%s) uses logDriver (%s). AWS CloudWatch logging disabled.", job_id, log_driver
+ )
+ return None
+
+ awslogs_stream_name = job_container_desc.get("logStreamName")
+ if not awslogs_stream_name:
+ # If this method is called at a very early stage of the AWS Batch job run,
+ # the AWS CloudWatch Stream Name may not exist yet.
+ # The AWS CloudWatch Stream Name is also not created in case of misconfiguration.
+ self.log.warning("AWS Batch job (%s) has not created an AWS CloudWatch Stream.", job_id)
+ return None
+
+ # Try to get user-defined log configuration options
+ log_options = log_configuration.get("options", {})
+
+ return {
+ "awslogs_stream_name": awslogs_stream_name,
+ "awslogs_group": log_options.get("awslogs-group", "/aws/batch/job"),
+ "awslogs_region": log_options.get("awslogs-region", self.conn_region_name),
+ }
+
@staticmethod
- def add_jitter(
- delay: Union[int, float], width: Union[int, float] = 1, minima: Union[int, float] = 0
- ) -> float:
+ def add_jitter(delay: int | float, width: int | float = 1, minima: int | float = 0) -> float:
"""
Use delay +/- width for random jitter
@@ -432,7 +465,6 @@ def add_jitter(
:return: uniform(delay - width, delay + width) jitter
and it is a non-negative number
- :rtype: float
"""
delay = abs(delay)
width = abs(width)
@@ -442,7 +474,7 @@ def add_jitter(
return uniform(lower, upper)
@staticmethod
- def delay(delay: Union[int, float, None] = None) -> None:
+ def delay(delay: int | float | None = None) -> None:
"""
Pause execution for ``delay`` seconds.
@@ -470,7 +502,6 @@ def exponential_delay(tries: int) -> float:
:param tries: Number of tries
- :rtype: float
Examples of behavior:
@@ -506,35 +537,3 @@ def exp(tries):
delay = 1 + pow(tries * 0.6, 2)
delay = min(max_interval, delay)
return uniform(delay / 3, delay)
-
-
-class AwsBatchProtocol(BatchProtocol, Protocol):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchProtocol`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchProtocol`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class AwsBatchClientHook(BatchClientHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchClientHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchClientHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/batch_waiters.py b/airflow/providers/amazon/aws/hooks/batch_waiters.py
index 59ba0e431f25e..0bbb982e41738 100644
--- a/airflow/providers/amazon/aws/hooks/batch_waiters.py
+++ b/airflow/providers/amazon/aws/hooks/batch_waiters.py
@@ -15,8 +15,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-
"""
AWS Batch service waiters
@@ -25,13 +23,12 @@
- https://boto3.amazonaws.com/v1/documentation/api/latest/guide/clients.html#waiters
- https://github.com/boto/botocore/blob/develop/botocore/waiter.py
"""
+from __future__ import annotations
import json
import sys
-import warnings
from copy import deepcopy
from pathlib import Path
-from typing import Dict, List, Optional, Union
import botocore.client
import botocore.exceptions
@@ -71,16 +68,12 @@ class BatchWaitersHook(BatchClientHook):
# and the details of the config on that waiter can be further modified without any
# accidental impact on the generation of new waiters from the defined waiter_model, e.g.
waiters.get_waiter("JobExists").config.delay # -> 5
- waiter = waiters.get_waiter(
- "JobExists"
- ) # -> botocore.waiter.Batch.Waiter.JobExists object
+ waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object
waiter.config.delay = 10
waiters.get_waiter("JobExists").config.delay # -> 5 as defined by waiter_model
# To use a specific waiter, update the config and call the `wait()` method for jobId, e.g.
- waiter = waiters.get_waiter(
- "JobExists"
- ) # -> botocore.waiter.Batch.Waiter.JobExists object
+ waiter = waiters.get_waiter("JobExists") # -> botocore.waiter.Batch.Waiter.JobExists object
waiter.config.delay = random.uniform(1, 10) # seconds
waiter.config.max_attempts = 10
waiter.wait(jobs=[jobId])
@@ -97,27 +90,26 @@ class BatchWaitersHook(BatchClientHook):
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
- (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+ (https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html).
:param region_name: region name to use in AWS client.
Override the AWS region in connection (if provided)
"""
- def __init__(self, *args, waiter_config: Optional[Dict] = None, **kwargs) -> None:
+ def __init__(self, *args, waiter_config: dict | None = None, **kwargs) -> None:
super().__init__(*args, **kwargs)
- self._default_config = None # type: Optional[Dict]
+ self._default_config: dict | None = None
self._waiter_config = waiter_config or self.default_config
self._waiter_model = botocore.waiter.WaiterModel(self._waiter_config)
@property
- def default_config(self) -> Dict:
+ def default_config(self) -> dict:
"""
An immutable default waiter configuration
:return: a waiter configuration for AWS Batch services
- :rtype: Dict
"""
if self._default_config is None:
config_path = Path(__file__).with_name("batch_waiters.json").resolve()
@@ -126,7 +118,7 @@ def default_config(self) -> Dict:
return deepcopy(self._default_config) # avoid accidental mutation
@property
- def waiter_config(self) -> Dict:
+ def waiter_config(self) -> dict:
"""
An immutable waiter configuration for this instance; a ``deepcopy`` is returned by this
property. During the init for BatchWaiters, the waiter_config is used to build a
@@ -134,7 +126,6 @@ def waiter_config(self) -> Dict:
mutations of waiter_config leaking into the waiter_model.
:return: a waiter configuration for AWS Batch services
- :rtype: Dict
"""
return deepcopy(self._waiter_config) # avoid accidental mutation
@@ -144,7 +135,6 @@ def waiter_model(self) -> botocore.waiter.WaiterModel:
A configured waiter model used to generate waiters on AWS Batch services.
:return: a waiter model for AWS Batch services
- :rtype: botocore.waiter.WaiterModel
"""
return self._waiter_model
@@ -179,20 +169,18 @@ def get_waiter(self, waiter_name: str) -> botocore.waiter.Waiter:
model file (typically this is CamelCasing); see ``.list_waiters``.
:return: a waiter object for the named AWS Batch service
- :rtype: botocore.waiter.Waiter
"""
return botocore.waiter.create_waiter_with_client(waiter_name, self.waiter_model, self.client)
- def list_waiters(self) -> List[str]:
+ def list_waiters(self) -> list[str]:
"""
List the waiters in a waiter configuration for AWS Batch services.
:return: waiter names for AWS Batch services
- :rtype: List[str]
"""
return self.waiter_model.waiter_names
- def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> None:
+ def wait_for_job(self, job_id: str, delay: int | float | None = None) -> None:
"""
Wait for Batch job to complete. This assumes that the ``.waiter_model`` is configured
using some variation of the ``.default_config`` so that it can generate waiters with the
@@ -231,19 +219,3 @@ def wait_for_job(self, job_id: str, delay: Union[int, float, None] = None) -> No
except (botocore.exceptions.ClientError, botocore.exceptions.WaiterError) as err:
raise AirflowException(err)
-
-
-class AwsBatchWaitersHook(BatchWaitersHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchWaitersHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.batch.BatchWaitersHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/cloud_formation.py b/airflow/providers/amazon/aws/hooks/cloud_formation.py
index e96f397628ed9..ac196ee602d1e 100644
--- a/airflow/providers/amazon/aws/hooks/cloud_formation.py
+++ b/airflow/providers/amazon/aws/hooks/cloud_formation.py
@@ -15,10 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains AWS CloudFormation Hook"""
-import warnings
-from typing import Optional, Union
+from __future__ import annotations
from boto3 import client, resource
from botocore.exceptions import ClientError
@@ -38,17 +36,17 @@ class CloudFormationHook(AwsBaseHook):
"""
def __init__(self, *args, **kwargs):
- super().__init__(client_type='cloudformation', *args, **kwargs)
+ super().__init__(client_type="cloudformation", *args, **kwargs)
- def get_stack_status(self, stack_name: Union[client, resource]) -> Optional[dict]:
+ def get_stack_status(self, stack_name: client | resource) -> dict | None:
"""Get stack status from CloudFormation."""
- self.log.info('Poking for stack %s', stack_name)
+ self.log.info("Poking for stack %s", stack_name)
try:
- stacks = self.get_conn().describe_stacks(StackName=stack_name)['Stacks']
- return stacks[0]['StackStatus']
+ stacks = self.get_conn().describe_stacks(StackName=stack_name)["Stacks"]
+ return stacks[0]["StackStatus"]
except ClientError as e:
- if 'does not exist' in str(e):
+ if "does not exist" in str(e):
return None
else:
raise e
@@ -60,11 +58,11 @@ def create_stack(self, stack_name: str, cloudformation_parameters: dict) -> None
:param stack_name: stack_name.
:param cloudformation_parameters: parameters to be passed to CloudFormation.
"""
- if 'StackName' not in cloudformation_parameters:
- cloudformation_parameters['StackName'] = stack_name
+ if "StackName" not in cloudformation_parameters:
+ cloudformation_parameters["StackName"] = stack_name
self.get_conn().create_stack(**cloudformation_parameters)
- def delete_stack(self, stack_name: str, cloudformation_parameters: Optional[dict] = None) -> None:
+ def delete_stack(self, stack_name: str, cloudformation_parameters: dict | None = None) -> None:
"""
Delete stack in CloudFormation.
@@ -72,22 +70,6 @@ def delete_stack(self, stack_name: str, cloudformation_parameters: Optional[dict
:param cloudformation_parameters: parameters to be passed to CloudFormation (optional).
"""
cloudformation_parameters = cloudformation_parameters or {}
- if 'StackName' not in cloudformation_parameters:
- cloudformation_parameters['StackName'] = stack_name
+ if "StackName" not in cloudformation_parameters:
+ cloudformation_parameters["StackName"] = stack_name
self.get_conn().delete_stack(**cloudformation_parameters)
-
-
-class AWSCloudFormationHook(CloudFormationHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.cloud_formation.CloudFormationHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/datasync.py b/airflow/providers/amazon/aws/hooks/datasync.py
index b75123d6fc6ec..89cca307502a9 100644
--- a/airflow/providers/amazon/aws/hooks/datasync.py
+++ b/airflow/providers/amazon/aws/hooks/datasync.py
@@ -14,12 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Interact with AWS DataSync, using the AWS ``boto3`` library."""
+from __future__ import annotations
import time
-import warnings
-from typing import List, Optional
from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowTaskTimeout
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -53,7 +51,7 @@ class DataSyncHook(AwsBaseHook):
TASK_EXECUTION_SUCCESS_STATES = ("SUCCESS",)
def __init__(self, wait_interval_seconds: int = 30, *args, **kwargs) -> None:
- super().__init__(client_type='datasync', *args, **kwargs) # type: ignore[misc]
+ super().__init__(client_type="datasync", *args, **kwargs) # type: ignore[misc]
self.locations: list = []
self.tasks: list = []
# wait_interval_seconds = 0 is used during unit tests
@@ -86,7 +84,7 @@ def create_location(self, location_uri: str, **create_location_kwargs) -> str:
def get_location_arns(
self, location_uri: str, case_sensitive: bool = False, ignore_trailing_slash: bool = True
- ) -> List[str]:
+ ) -> list[str]:
"""
Return all LocationArns which match a LocationUri.
@@ -94,7 +92,6 @@ def get_location_arns(
:param bool case_sensitive: Do a case sensitive search for location URI.
:param bool ignore_trailing_slash: Ignore / at the end of URI when matching.
:return: List of LocationArns.
- :rtype: list(str)
:raises AirflowBadRequest: if ``location_uri`` is empty
"""
if not location_uri:
@@ -141,7 +138,6 @@ def create_task(
:param str destination_location_arn: Destination LocationArn. Must exist already.
:param create_task_kwargs: Passed to ``boto.create_task()``. See AWS boto3 datasync documentation.
:return: TaskArn of the created Task
- :rtype: str
"""
task = self.get_conn().create_task(
SourceLocationArn=source_location_arn,
@@ -192,7 +188,6 @@ def get_task_arns_for_location_arns(
:param list source_location_arns: List of source LocationArns.
:param list destination_location_arns: List of destination LocationArns.
:return: list
- :rtype: list(TaskArns)
:raises AirflowBadRequest: if ``source_location_arns`` or ``destination_location_arns`` are empty.
"""
if not source_location_arns:
@@ -219,7 +214,6 @@ def start_task_execution(self, task_arn: str, **kwargs) -> str:
:param str task_arn: TaskArn
:return: TaskExecutionArn
:param kwargs: kwargs sent to ``boto3.start_task_execution()``
- :rtype: str
:raises ClientError: If a TaskExecution is already busy running for this ``task_arn``.
:raises AirflowBadRequest: If ``task_arn`` is empty.
"""
@@ -245,7 +239,6 @@ def get_task_description(self, task_arn: str) -> dict:
:param str task_arn: TaskArn
:return: AWS metadata about a task.
- :rtype: dict
:raises AirflowBadRequest: If ``task_arn`` is empty.
"""
if not task_arn:
@@ -258,18 +251,16 @@ def describe_task_execution(self, task_execution_arn: str) -> dict:
:param str task_execution_arn: TaskExecutionArn
:return: AWS metadata about a task execution.
- :rtype: dict
:raises AirflowBadRequest: If ``task_execution_arn`` is empty.
"""
return self.get_conn().describe_task_execution(TaskExecutionArn=task_execution_arn)
- def get_current_task_execution_arn(self, task_arn: str) -> Optional[str]:
+ def get_current_task_execution_arn(self, task_arn: str) -> str | None:
"""
Get current TaskExecutionArn (if one exists) for the specified ``task_arn``.
:param str task_arn: TaskArn
:return: CurrentTaskExecutionArn for this ``task_arn`` or None.
- :rtype: str
:raises AirflowBadRequest: if ``task_arn`` is empty.
"""
if not task_arn:
@@ -287,7 +278,6 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int =
:param str task_execution_arn: TaskExecutionArn
:param int max_iterations: Maximum number of iterations before timing out.
:return: Result of task execution.
- :rtype: bool
:raises AirflowTaskTimeout: If maximum iterations is exceeded.
:raises AirflowBadRequest: If ``task_execution_arn`` is empty.
"""
@@ -316,18 +306,3 @@ def wait_for_task_execution(self, task_execution_arn: str, max_iterations: int =
if iterations <= 0:
raise AirflowTaskTimeout("Max iterations exceeded!")
raise AirflowException(f"Unknown status: {status}") # Should never happen
-
-
-class AWSDataSyncHook(DataSyncHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.datasync.DataSyncHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.datasync.DataSyncHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/dms.py b/airflow/providers/amazon/aws/hooks/dms.py
index a1bd19daf3129..7ca541dd41f54 100644
--- a/airflow/providers/amazon/aws/hooks/dms.py
+++ b/airflow/providers/amazon/aws/hooks/dms.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import json
from enum import Enum
-from typing import Optional
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -25,10 +26,10 @@
class DmsTaskWaiterStatus(str, Enum):
"""Available AWS DMS Task Waiter statuses."""
- DELETED = 'deleted'
- READY = 'ready'
- RUNNING = 'running'
- STOPPED = 'stopped'
+ DELETED = "deleted"
+ READY = "ready"
+ RUNNING = "running"
+ STOPPED = "stopped"
class DmsHook(AwsBaseHook):
@@ -39,24 +40,21 @@ def __init__(
*args,
**kwargs,
):
- kwargs['client_type'] = 'dms'
+ kwargs["client_type"] = "dms"
super().__init__(*args, **kwargs)
- def describe_replication_tasks(self, **kwargs):
+ def describe_replication_tasks(self, **kwargs) -> tuple[str | None, list]:
"""
Describe replication tasks
:return: Marker and list of replication tasks
- :rtype: (Optional[str], list)
"""
dms_client = self.get_conn()
response = dms_client.describe_replication_tasks(**kwargs)
- return response.get('Marker'), response.get('ReplicationTasks', [])
+ return response.get("Marker"), response.get("ReplicationTasks", [])
- def find_replication_tasks_by_arn(
- self, replication_task_arn: str, without_settings: Optional[bool] = False
- ):
+ def find_replication_tasks_by_arn(self, replication_task_arn: str, without_settings: bool | None = False):
"""
Find and describe replication tasks by task ARN
:param replication_task_arn: Replication task arn
@@ -67,8 +65,8 @@ def find_replication_tasks_by_arn(
_, tasks = self.describe_replication_tasks(
Filters=[
{
- 'Name': 'replication-task-arn',
- 'Values': [replication_task_arn],
+ "Name": "replication-task-arn",
+ "Values": [replication_task_arn],
}
],
WithoutSettings=without_settings,
@@ -76,7 +74,7 @@ def find_replication_tasks_by_arn(
return tasks
- def get_task_status(self, replication_task_arn: str) -> Optional[str]:
+ def get_task_status(self, replication_task_arn: str) -> str | None:
"""
Retrieve task status.
@@ -89,11 +87,11 @@ def get_task_status(self, replication_task_arn: str) -> Optional[str]:
)
if len(replication_tasks) == 1:
- status = replication_tasks[0]['Status']
+ status = replication_tasks[0]["Status"]
self.log.info('Replication task with ARN(%s) has status "%s".', replication_task_arn, status)
return status
else:
- self.log.info('Replication task with ARN(%s) is not found.', replication_task_arn)
+ self.log.info("Replication task with ARN(%s) is not found.", replication_task_arn)
return None
def create_replication_task(
@@ -128,7 +126,7 @@ def create_replication_task(
**kwargs,
)
- replication_task_arn = create_task_response['ReplicationTask']['ReplicationTaskArn']
+ replication_task_arn = create_task_response["ReplicationTask"]["ReplicationTaskArn"]
self.wait_for_task_status(replication_task_arn, DmsTaskWaiterStatus.READY)
return replication_task_arn
@@ -182,15 +180,15 @@ def wait_for_task_status(self, replication_task_arn: str, status: DmsTaskWaiterS
:param replication_task_arn: Replication task ARN
"""
if not isinstance(status, DmsTaskWaiterStatus):
- raise TypeError('Status must be an instance of DmsTaskWaiterStatus')
+ raise TypeError("Status must be an instance of DmsTaskWaiterStatus")
dms_client = self.get_conn()
- waiter = dms_client.get_waiter(f'replication_task_{status}')
+ waiter = dms_client.get_waiter(f"replication_task_{status}")
waiter.wait(
Filters=[
{
- 'Name': 'replication-task-arn',
- 'Values': [
+ "Name": "replication-task-arn",
+ "Values": [
replication_task_arn,
],
},
diff --git a/airflow/providers/amazon/aws/hooks/dynamodb.py b/airflow/providers/amazon/aws/hooks/dynamodb.py
index 7b298ee15ca62..52c96e9b02c23 100644
--- a/airflow/providers/amazon/aws/hooks/dynamodb.py
+++ b/airflow/providers/amazon/aws/hooks/dynamodb.py
@@ -15,11 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
"""This module contains the AWS DynamoDB hook"""
-import warnings
-from typing import Iterable, List, Optional
+from __future__ import annotations
+
+from typing import Iterable
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -40,7 +39,7 @@ class DynamoDBHook(AwsBaseHook):
"""
def __init__(
- self, *args, table_keys: Optional[List] = None, table_name: Optional[str] = None, **kwargs
+ self, *args, table_keys: list | None = None, table_name: str | None = None, **kwargs
) -> None:
self.table_keys = table_keys
self.table_name = table_name
@@ -58,19 +57,3 @@ def write_batch_data(self, items: Iterable) -> bool:
return True
except Exception as general_error:
raise AirflowException(f"Failed to insert items in dynamodb, error: {str(general_error)}")
-
-
-class AwsDynamoDBHook(DynamoDBHook):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.dynamodb.DynamoDBHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.dynamodb.DynamoDBHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/ec2.py b/airflow/providers/amazon/aws/hooks/ec2.py
index 96dbaf541051d..bb715c4bf2ca1 100644
--- a/airflow/providers/amazon/aws/hooks/ec2.py
+++ b/airflow/providers/amazon/aws/hooks/ec2.py
@@ -15,11 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
import functools
import time
-from typing import List, Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -31,16 +30,17 @@ def checker(self, *args, **kwargs):
if self._api_type == "client_type":
return func(self, *args, **kwargs)
+ ec2_doc_link = "https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ec2.html"
raise AirflowException(
- """
+ f"""
This method is only callable when using client_type API for interacting with EC2.
Create the EC2Hook object as follows to use this method
ec2 = EC2Hook(api_type="client_type")
Read following for details on client_type and resource_type APIs:
- 1. https://boto3.amazonaws.com/v1/documentation/api/1.9.42/reference/services/ec2.html#client
- 2. https://boto3.amazonaws.com/v1/documentation/api/1.9.42/reference/services/ec2.html#service-resource # noqa
+ 1. {ec2_doc_link}#client
+ 2. {ec2_doc_link}#service-resource
"""
)
@@ -70,14 +70,13 @@ def __init__(self, api_type="resource_type", *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
- def get_instance(self, instance_id: str, filters: Optional[List] = None):
+ def get_instance(self, instance_id: str, filters: list | None = None):
"""
Get EC2 instance by id and return it.
:param instance_id: id of the AWS EC2 instance
:param filters: List of filters to specify instances to get
:return: Instance object
- :rtype: ec2.Instance
"""
if self._api_type == "client_type":
return self.get_instances(filters=filters, instance_ids=[instance_id])
@@ -121,7 +120,7 @@ def terminate_instances(self, instance_ids: list) -> dict:
return self.conn.terminate_instances(InstanceIds=instance_ids)
@only_client_type
- def describe_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None):
+ def describe_instances(self, filters: list | None = None, instance_ids: list | None = None):
"""
Describe EC2 instances, optionally applying filters and selective instance ids
@@ -138,7 +137,7 @@ def describe_instances(self, filters: Optional[List] = None, instance_ids: Optio
return self.conn.describe_instances(Filters=filters, InstanceIds=instance_ids)
@only_client_type
- def get_instances(self, filters: Optional[List] = None, instance_ids: Optional[List] = None) -> list:
+ def get_instances(self, filters: list | None = None, instance_ids: list | None = None) -> list:
"""
Get list of instance details, optionally applying filters and selective instance ids
@@ -153,7 +152,7 @@ def get_instances(self, filters: Optional[List] = None, instance_ids: Optional[L
]
@only_client_type
- def get_instance_ids(self, filters: Optional[List] = None) -> list:
+ def get_instance_ids(self, filters: list | None = None) -> list:
"""
Get list of instance ids, optionally applying filters to fetch selective instances
@@ -168,7 +167,6 @@ def get_instance_state(self, instance_id: str) -> str:
:param instance_id: id of the AWS EC2 instance
:return: current state of the instance
- :rtype: str
"""
if self._api_type == "client_type":
return self.get_instances(instance_ids=[instance_id])[0]["State"]["Name"]
@@ -184,7 +182,6 @@ def wait_for_state(self, instance_id: str, target_state: str, check_interval: fl
:param check_interval: time in seconds that the job should wait in
between each instance state checks until operation is completed
:return: None
- :rtype: None
"""
instance_state = self.get_instance_state(instance_id=instance_id)
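The ``checker`` wrapper touched in the ``ec2.py`` hunks above guards methods that only work with the boto3 client API. The following is a minimal sketch of that guard pattern, not the provider's actual code: ``ApiTypeError`` and ``FakeEc2Hook`` are hypothetical stand-ins (the real hook raises ``AirflowException`` and talks to AWS), and only the ``_api_type`` check mirrors the diff.

```python
import functools


class ApiTypeError(RuntimeError):
    """Hypothetical stand-in for AirflowException (keeps the sketch dependency-free)."""


def only_client_type(func):
    """Allow the wrapped method only when the instance uses the client API."""

    @functools.wraps(func)
    def checker(self, *args, **kwargs):
        if self._api_type == "client_type":
            return func(self, *args, **kwargs)
        raise ApiTypeError(f"{func.__name__} requires api_type='client_type'")

    return checker


class FakeEc2Hook:
    """Hypothetical stand-in for EC2Hook; it makes no AWS calls."""

    def __init__(self, api_type="resource_type"):
        self._api_type = api_type

    @only_client_type
    def get_instance_ids(self):
        # A real hook would query EC2 here; the value is made up.
        return ["i-0123"]
```

Calling ``FakeEc2Hook(api_type="client_type").get_instance_ids()`` returns the list, while the default ``resource_type`` instance raises ``ApiTypeError`` before the method body runs.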
diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py
new file mode 100644
index 0000000000000..f5a9945d92fa3
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -0,0 +1,225 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import time
+from collections import deque
+from datetime import datetime, timedelta
+from enum import Enum
+from logging import Logger
+from threading import Event, Thread
+from typing import Generator
+
+from botocore.exceptions import ClientError, ConnectionClosedError
+from botocore.waiter import Waiter
+
+from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.typing_compat import Protocol, runtime_checkable
+
+
+def should_retry(exception: Exception):
+ """Check if exception is related to ECS resource quota (CPU, MEM)."""
+ if isinstance(exception, EcsOperatorError):
+ return any(
+ quota_reason in failure["reason"]
+ for quota_reason in ["RESOURCE:MEMORY", "RESOURCE:CPU"]
+ for failure in exception.failures
+ )
+ return False
+
+
+def should_retry_eni(exception: Exception):
+ """Check if exception is related to ENI (Elastic Network Interfaces)."""
+ if isinstance(exception, EcsTaskFailToStart):
+ return any(
+ eni_reason in exception.message
+ for eni_reason in ["network interface provisioning", "ResourceInitializationError"]
+ )
+ return False
+
+
+class EcsClusterStates(Enum):
+ """Contains the possible State values of an ECS Cluster."""
+
+ ACTIVE = "ACTIVE"
+ PROVISIONING = "PROVISIONING"
+ DEPROVISIONING = "DEPROVISIONING"
+ FAILED = "FAILED"
+ INACTIVE = "INACTIVE"
+
+
+class EcsTaskDefinitionStates(Enum):
+ """Contains the possible State values of an ECS Task Definition."""
+
+ ACTIVE = "ACTIVE"
+ INACTIVE = "INACTIVE"
+
+
+class EcsTaskStates(Enum):
+ """Contains the possible State values of an ECS Task."""
+
+ PROVISIONING = "PROVISIONING"
+ PENDING = "PENDING"
+ ACTIVATING = "ACTIVATING"
+ RUNNING = "RUNNING"
+ DEACTIVATING = "DEACTIVATING"
+ STOPPING = "STOPPING"
+ DEPROVISIONING = "DEPROVISIONING"
+ STOPPED = "STOPPED"
+ NONE = "NONE"
+
+
+class EcsHook(AwsGenericHook):
+ """
+ Interact with AWS Elastic Container Service, using the boto3 library.
+ Hook attribute `conn` has all methods that are listed in the documentation.
+
+ .. seealso::
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html
+ - https://docs.aws.amazon.com/AmazonECS/latest/APIReference/Welcome.html
+
+ Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and
+ are passed down to the underlying AwsGenericHook.
+
+ .. seealso::
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
+
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ kwargs["client_type"] = "ecs"
+ super().__init__(*args, **kwargs)
+
+ def get_cluster_state(self, cluster_name: str) -> str:
+ return self.conn.describe_clusters(clusters=[cluster_name])["clusters"][0]["status"]
+
+ def get_task_definition_state(self, task_definition: str) -> str:
+ return self.conn.describe_task_definition(taskDefinition=task_definition)["taskDefinition"]["status"]
+
+ def get_task_state(self, cluster, task) -> str:
+ return self.conn.describe_tasks(cluster=cluster, tasks=[task])["tasks"][0]["lastStatus"]
+
+
+class EcsTaskLogFetcher(Thread):
+ """
+ Fetches CloudWatch log events at a specific interval in a separate thread
+ and sends the log events to the info channel of the provided logger.
+ """
+
+ def __init__(
+ self,
+ *,
+ log_group: str,
+ log_stream_name: str,
+ fetch_interval: timedelta,
+ logger: Logger,
+ aws_conn_id: str | None = "aws_default",
+ region_name: str | None = None,
+ ):
+ super().__init__()
+ self._event = Event()
+
+ self.fetch_interval = fetch_interval
+
+ self.logger = logger
+ self.log_group = log_group
+ self.log_stream_name = log_stream_name
+
+ self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name)
+
+ def run(self) -> None:
+ logs_to_skip = 0
+ while not self.is_stopped():
+ time.sleep(self.fetch_interval.total_seconds())
+ log_events = self._get_log_events(logs_to_skip)
+ for log_event in log_events:
+ self.logger.info(self._event_to_str(log_event))
+ logs_to_skip += 1
+
+ def _get_log_events(self, skip: int = 0) -> Generator:
+ try:
+ yield from self.hook.get_log_events(self.log_group, self.log_stream_name, skip=skip)
+ except ClientError as error:
+ if error.response["Error"]["Code"] != "ResourceNotFoundException":
+ self.logger.warning("Error on retrieving Cloudwatch log events: %s", error)
+
+ yield from ()
+ except ConnectionClosedError as error:
+ self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events: %s", error)
+ yield from ()
+
+ def _event_to_str(self, event: dict) -> str:
+ event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
+ formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+ message = event["message"]
+ return f"[{formatted_event_dt}] {message}"
+
+ def get_last_log_messages(self, number_messages) -> list:
+ return [log["message"] for log in deque(self._get_log_events(), maxlen=number_messages)]
+
+ def get_last_log_message(self) -> str | None:
+ try:
+ return self.get_last_log_messages(1)[0]
+ except IndexError:
+ return None
+
+ def is_stopped(self) -> bool:
+ return self._event.is_set()
+
+ def stop(self):
+ self._event.set()
+
+
+@runtime_checkable
+class EcsProtocol(Protocol):
+ """
+ A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on
+ :py:meth:`.EcsOperator.client`.
+
+ .. seealso::
+
+ - https://mypy.readthedocs.io/en/latest/protocols.html
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html
+ """
+
+ def run_task(self, **kwargs) -> dict:
+ """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501
+ ...
+
+ def get_waiter(self, x: str) -> Waiter:
+ """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501
+ ...
+
+ def describe_tasks(self, cluster: str, tasks) -> dict:
+ """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501
+ ...
+
+ def stop_task(self, cluster, task, reason: str) -> dict:
+ """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501
+ ...
+
+ def describe_task_definition(self, taskDefinition: str) -> dict:
+ """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_task_definition""" # noqa: E501
+ ...
+
+ def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family: str) -> dict:
+ """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.list_tasks""" # noqa: E501
+ ...
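``EcsTaskLogFetcher.get_last_log_messages`` in the new ``ecs.py`` above relies on ``collections.deque`` with ``maxlen`` to keep only the tail of a log stream. Here is a small sketch of that idiom under stated assumptions: ``fake_log_events`` is a hypothetical generator standing in for ``AwsLogsHook.get_log_events``, and the event payloads are invented for illustration.

```python
from collections import deque
from typing import Generator


def fake_log_events() -> Generator:
    """Hypothetical event source standing in for AwsLogsHook.get_log_events."""
    for i in range(5):
        yield {"timestamp": 1_600_000_000_000 + i, "message": f"line {i}"}


def get_last_log_messages(number_messages: int) -> list:
    # deque(maxlen=n) discards the oldest items as new ones arrive, so only
    # the final n events are held in memory once the generator is drained.
    return [log["message"] for log in deque(fake_log_events(), maxlen=number_messages)]


def get_last_log_message():
    # An empty stream yields an empty list, hence the IndexError fallback.
    try:
        return get_last_log_messages(1)[0]
    except IndexError:
        return None
```

The whole stream is still consumed on every call; the bounded deque only limits how much of it is retained.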
diff --git a/airflow/providers/amazon/aws/hooks/eks.py b/airflow/providers/amazon/aws/hooks/eks.py
index d2a795e498dfd..03488dea96771 100644
--- a/airflow/providers/amazon/aws/hooks/eks.py
+++ b/airflow/providers/amazon/aws/hooks/eks.py
@@ -14,30 +14,30 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Interact with Amazon EKS, using the boto3 library."""
+from __future__ import annotations
+
import base64
import json
import sys
import tempfile
-import warnings
from contextlib import contextmanager
from enum import Enum
from functools import partial
-from typing import Callable, Dict, Generator, List, Optional
+from typing import Callable, Generator
-import yaml
from botocore.exceptions import ClientError
from botocore.signers import RequestSigner
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.utils import yaml
from airflow.utils.json import AirflowJsonEncoder
-DEFAULT_PAGINATION_TOKEN = ''
+DEFAULT_PAGINATION_TOKEN = ""
STS_TOKEN_EXPIRES_IN = 60
AUTHENTICATION_API_VERSION = "client.authentication.k8s.io/v1alpha1"
-_POD_USERNAME = 'aws'
-_CONTEXT_NAME = 'aws'
+_POD_USERNAME = "aws"
+_CONTEXT_NAME = "aws"
class ClusterStates(Enum):
@@ -86,7 +86,7 @@ class EksHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- client_type = 'eks'
+ client_type = "eks"
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = self.client_type
@@ -96,9 +96,9 @@ def create_cluster(
self,
name: str,
roleArn: str,
- resourcesVpcConfig: Dict,
+ resourcesVpcConfig: dict,
**kwargs,
- ) -> Dict:
+ ) -> dict:
"""
Creates an Amazon EKS control plane.
@@ -108,7 +108,6 @@ def create_cluster(
:param resourcesVpcConfig: The VPC configuration used by the cluster control plane.
:return: Returns descriptive information about the created EKS Cluster.
- :rtype: Dict
"""
eks_client = self.conn
@@ -116,19 +115,19 @@ def create_cluster(
name=name, roleArn=roleArn, resourcesVpcConfig=resourcesVpcConfig, **kwargs
)
- self.log.info("Created Amazon EKS cluster with the name %s.", response.get('cluster').get('name'))
+ self.log.info("Created Amazon EKS cluster with the name %s.", response.get("cluster").get("name"))
return response
def create_nodegroup(
self,
clusterName: str,
nodegroupName: str,
- subnets: List[str],
- nodeRole: Optional[str],
+ subnets: list[str],
+ nodeRole: str | None,
*,
- tags: Optional[Dict] = None,
+ tags: dict | None = None,
**kwargs,
- ) -> Dict:
+ ) -> dict:
"""
Creates an Amazon EKS managed node group for an Amazon EKS Cluster.
@@ -139,17 +138,16 @@ def create_nodegroup(
:param tags: Optional tags to apply to your nodegroup.
:return: Returns descriptive information about the created EKS Managed Nodegroup.
- :rtype: Dict
"""
eks_client = self.conn
# The below tag is mandatory and must have a value of either 'owned' or 'shared'
# A value of 'owned' denotes that the subnets are exclusive to the nodegroup.
# The 'shared' value allows more than one resource to use the subnet.
- cluster_tag_key = f'kubernetes.io/cluster/{clusterName}'
+ cluster_tag_key = f"kubernetes.io/cluster/{clusterName}"
resolved_tags = tags or {}
if cluster_tag_key not in resolved_tags:
- resolved_tags[cluster_tag_key] = 'owned'
+ resolved_tags[cluster_tag_key] = "owned"
response = eks_client.create_nodegroup(
clusterName=clusterName,
@@ -162,19 +160,19 @@ def create_nodegroup(
self.log.info(
"Created an Amazon EKS managed node group named %s in Amazon EKS cluster %s",
- response.get('nodegroup').get('nodegroupName'),
- response.get('nodegroup').get('clusterName'),
+ response.get("nodegroup").get("nodegroupName"),
+ response.get("nodegroup").get("clusterName"),
)
return response
def create_fargate_profile(
self,
clusterName: str,
- fargateProfileName: Optional[str],
- podExecutionRoleArn: Optional[str],
- selectors: List,
+ fargateProfileName: str | None,
+ podExecutionRoleArn: str | None,
+ selectors: list,
**kwargs,
- ) -> Dict:
+ ) -> dict:
"""
Creates an AWS Fargate profile for an Amazon EKS cluster.
@@ -185,7 +183,6 @@ def create_fargate_profile(
:param selectors: The selectors to match for pods to use this Fargate profile.
:return: Returns descriptive information about the created Fargate profile.
- :rtype: Dict
"""
eks_client = self.conn
@@ -199,28 +196,27 @@ def create_fargate_profile(
self.log.info(
"Created AWS Fargate profile with the name %s for Amazon EKS cluster %s.",
- response.get('fargateProfile').get('fargateProfileName'),
- response.get('fargateProfile').get('clusterName'),
+ response.get("fargateProfile").get("fargateProfileName"),
+ response.get("fargateProfile").get("clusterName"),
)
return response
- def delete_cluster(self, name: str) -> Dict:
+ def delete_cluster(self, name: str) -> dict:
"""
Deletes the Amazon EKS Cluster control plane.
:param name: The name of the cluster to delete.
:return: Returns descriptive information about the deleted EKS Cluster.
- :rtype: Dict
"""
eks_client = self.conn
response = eks_client.delete_cluster(name=name)
- self.log.info("Deleted Amazon EKS cluster with the name %s.", response.get('cluster').get('name'))
+ self.log.info("Deleted Amazon EKS cluster with the name %s.", response.get("cluster").get("name"))
return response
- def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> Dict:
+ def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> dict:
"""
Deletes an Amazon EKS managed node group from a specified cluster.
@@ -228,7 +224,6 @@ def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> Dict:
:param nodegroupName: The name of the nodegroup to delete.
:return: Returns descriptive information about the deleted EKS Managed Nodegroup.
- :rtype: Dict
"""
eks_client = self.conn
@@ -236,12 +231,12 @@ def delete_nodegroup(self, clusterName: str, nodegroupName: str) -> Dict:
self.log.info(
"Deleted Amazon EKS managed node group named %s from Amazon EKS cluster %s.",
- response.get('nodegroup').get('nodegroupName'),
- response.get('nodegroup').get('clusterName'),
+ response.get("nodegroup").get("nodegroupName"),
+ response.get("nodegroup").get("clusterName"),
)
return response
- def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> Dict:
+ def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> dict:
"""
Deletes an AWS Fargate profile from a specified Amazon EKS cluster.
@@ -249,7 +244,6 @@ def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> D
:param fargateProfileName: The name of the Fargate profile to delete.
:return: Returns descriptive information about the deleted Fargate profile.
- :rtype: Dict
"""
eks_client = self.conn
@@ -259,12 +253,12 @@ def delete_fargate_profile(self, clusterName: str, fargateProfileName: str) -> D
self.log.info(
"Deleted AWS Fargate profile with the name %s from Amazon EKS cluster %s.",
- response.get('fargateProfile').get('fargateProfileName'),
- response.get('fargateProfile').get('clusterName'),
+ response.get("fargateProfile").get("fargateProfileName"),
+ response.get("fargateProfile").get("clusterName"),
)
return response
- def describe_cluster(self, name: str, verbose: bool = False) -> Dict:
+ def describe_cluster(self, name: str, verbose: bool = False) -> dict:
"""
Returns descriptive information about an Amazon EKS Cluster.
@@ -272,21 +266,20 @@ def describe_cluster(self, name: str, verbose: bool = False) -> Dict:
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: Returns descriptive information about a specific EKS Cluster.
- :rtype: Dict
"""
eks_client = self.conn
response = eks_client.describe_cluster(name=name)
self.log.info(
- "Retrieved details for Amazon EKS cluster named %s.", response.get('cluster').get('name')
+ "Retrieved details for Amazon EKS cluster named %s.", response.get("cluster").get("name")
)
if verbose:
- cluster_data = response.get('cluster')
+ cluster_data = response.get("cluster")
self.log.info("Amazon EKS cluster details: %s", json.dumps(cluster_data, cls=AirflowJsonEncoder))
return response
- def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> Dict:
+ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool = False) -> dict:
"""
Returns descriptive information about an Amazon EKS managed node group.
@@ -295,7 +288,6 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: Returns descriptive information about a specific EKS Nodegroup.
- :rtype: Dict
"""
eks_client = self.conn
@@ -303,11 +295,11 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool
self.log.info(
"Retrieved details for Amazon EKS managed node group named %s in Amazon EKS cluster %s.",
- response.get('nodegroup').get('nodegroupName'),
- response.get('nodegroup').get('clusterName'),
+ response.get("nodegroup").get("nodegroupName"),
+ response.get("nodegroup").get("clusterName"),
)
if verbose:
- nodegroup_data = response.get('nodegroup')
+ nodegroup_data = response.get("nodegroup")
self.log.info(
"Amazon EKS managed node group details: %s",
json.dumps(nodegroup_data, cls=AirflowJsonEncoder),
@@ -316,7 +308,7 @@ def describe_nodegroup(self, clusterName: str, nodegroupName: str, verbose: bool
def describe_fargate_profile(
self, clusterName: str, fargateProfileName: str, verbose: bool = False
- ) -> Dict:
+ ) -> dict:
"""
Returns descriptive information about an AWS Fargate profile.
@@ -325,7 +317,6 @@ def describe_fargate_profile(
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: Returns descriptive information about an AWS Fargate profile.
- :rtype: Dict
"""
eks_client = self.conn
@@ -335,11 +326,11 @@ def describe_fargate_profile(
self.log.info(
"Retrieved details for AWS Fargate profile named %s in Amazon EKS cluster %s.",
- response.get('fargateProfile').get('fargateProfileName'),
- response.get('fargateProfile').get('clusterName'),
+ response.get("fargateProfile").get("fargateProfileName"),
+ response.get("fargateProfile").get("clusterName"),
)
if verbose:
- fargate_profile_data = response.get('fargateProfile')
+ fargate_profile_data = response.get("fargateProfile")
self.log.info(
"AWS Fargate profile details: %s", json.dumps(fargate_profile_data, cls=AirflowJsonEncoder)
)
@@ -352,12 +343,11 @@ def get_cluster_state(self, clusterName: str) -> ClusterStates:
:param clusterName: The name of the cluster to check.
:return: Returns the current status of a given Amazon EKS Cluster.
- :rtype: ClusterStates
"""
eks_client = self.conn
try:
- return ClusterStates(eks_client.describe_cluster(name=clusterName).get('cluster').get('status'))
+ return ClusterStates(eks_client.describe_cluster(name=clusterName).get("cluster").get("status"))
except ClientError as ex:
if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
return ClusterStates.NONEXISTENT
@@ -371,7 +361,6 @@ def get_fargate_profile_state(self, clusterName: str, fargateProfileName: str) -
:param fargateProfileName: The name of the Fargate profile to check.
:return: Returns the current status of a given AWS Fargate profile.
- :rtype: AWS FargateProfileStates
"""
eks_client = self.conn
@@ -380,8 +369,8 @@ def get_fargate_profile_state(self, clusterName: str, fargateProfileName: str) -
eks_client.describe_fargate_profile(
clusterName=clusterName, fargateProfileName=fargateProfileName
)
- .get('fargateProfile')
- .get('status')
+ .get("fargateProfile")
+ .get("status")
)
except ClientError as ex:
if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
@@ -396,15 +385,14 @@ def get_nodegroup_state(self, clusterName: str, nodegroupName: str) -> Nodegroup
:param nodegroupName: The name of the nodegroup to check.
:return: Returns the current status of a given Amazon EKS Nodegroup.
- :rtype: NodegroupStates
"""
eks_client = self.conn
try:
return NodegroupStates(
eks_client.describe_nodegroup(clusterName=clusterName, nodegroupName=nodegroupName)
- .get('nodegroup')
- .get('status')
+ .get("nodegroup")
+ .get("status")
)
except ClientError as ex:
if ex.response.get("Error").get("Code") == "ResourceNotFoundException":
@@ -414,14 +402,13 @@ def get_nodegroup_state(self, clusterName: str, nodegroupName: str) -> Nodegroup
def list_clusters(
self,
verbose: bool = False,
- ) -> List:
+ ) -> list:
"""
Lists all Amazon EKS Clusters in your AWS account.
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: A List containing the cluster names.
- :rtype: List
"""
eks_client = self.conn
list_cluster_call = partial(eks_client.list_clusters)
@@ -432,7 +419,7 @@ def list_nodegroups(
self,
clusterName: str,
verbose: bool = False,
- ) -> List:
+ ) -> list:
"""
Lists all Amazon EKS managed node groups associated with the specified cluster.
@@ -440,7 +427,6 @@ def list_nodegroups(
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: A List of nodegroup names within the given cluster.
- :rtype: List
"""
eks_client = self.conn
list_nodegroups_call = partial(eks_client.list_nodegroups, clusterName=clusterName)
@@ -451,7 +437,7 @@ def list_fargate_profiles(
self,
clusterName: str,
verbose: bool = False,
- ) -> List:
+ ) -> list:
"""
Lists all AWS Fargate profiles associated with the specified cluster.
@@ -459,7 +445,6 @@ def list_fargate_profiles(
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: A list of Fargate profile names within a given cluster.
- :rtype: List
"""
eks_client = self.conn
list_fargate_profiles_call = partial(eks_client.list_fargate_profiles, clusterName=clusterName)
@@ -468,7 +453,7 @@ def list_fargate_profiles(
api_call=list_fargate_profiles_call, response_key="fargateProfileNames", verbose=verbose
)
- def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> List:
+ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> list:
"""
Repeatedly calls a provided boto3 API Callable and collates the responses into a List.
@@ -477,9 +462,8 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis
:param verbose: Provides additional logging if set to True. Defaults to False.
:return: A List of the combined results of the provided API call.
- :rtype: List
"""
- name_collection: List = []
+ name_collection: list = []
token = DEFAULT_PAGINATION_TOKEN
while token is not None:
@@ -498,9 +482,7 @@ def _list_all(self, api_call: Callable, response_key: str, verbose: bool) -> Lis
def generate_config_file(
self,
eks_cluster_name: str,
- pod_namespace: Optional[str],
- pod_username: Optional[str] = None,
- pod_context: Optional[str] = None,
+ pod_namespace: str | None,
) -> Generator[str, None, None]:
"""
Writes the kubeconfig file given an EKS Cluster.
@@ -508,20 +490,6 @@ def generate_config_file(
:param eks_cluster_name: The name of the cluster to generate kubeconfig file for.
:param pod_namespace: The namespace to run within kubernetes.
"""
- if pod_username:
- warnings.warn(
- "This pod_username parameter is deprecated, because changing the value does not make any "
- "visible changes to the user.",
- DeprecationWarning,
- stacklevel=2,
- )
- if pod_context:
- warnings.warn(
- "This pod_context parameter is deprecated, because changing the value does not make any "
- "visible changes to the user.",
- DeprecationWarning,
- stacklevel=2,
- )
# Set up the client
eks_client = self.conn
@@ -588,7 +556,7 @@ def generate_config_file(
}
config_text = yaml.dump(cluster_config, default_flow_style=False)
- with tempfile.NamedTemporaryFile(mode='w') as config_file:
+ with tempfile.NamedTemporaryFile(mode="w") as config_file:
config_file.write(config_text)
config_file.flush()
yield config_file.name
@@ -597,49 +565,34 @@ def fetch_access_token_for_cluster(self, eks_cluster_name: str) -> str:
session = self.get_session()
service_id = self.conn.meta.service_model.service_id
sts_url = (
- f'https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15'
+ f"https://sts.{session.region_name}.amazonaws.com/?Action=GetCallerIdentity&Version=2011-06-15"
)
signer = RequestSigner(
service_id=service_id,
region_name=session.region_name,
- signing_name='sts',
- signature_version='v4',
+ signing_name="sts",
+ signature_version="v4",
credentials=session.get_credentials(),
event_emitter=session.events,
)
request_params = {
- 'method': 'GET',
- 'url': sts_url,
- 'body': {},
- 'headers': {'x-k8s-aws-id': eks_cluster_name},
- 'context': {},
+ "method": "GET",
+ "url": sts_url,
+ "body": {},
+ "headers": {"x-k8s-aws-id": eks_cluster_name},
+ "context": {},
}
signed_url = signer.generate_presigned_url(
request_dict=request_params,
region_name=session.region_name,
expires_in=STS_TOKEN_EXPIRES_IN,
- operation_name='',
+ operation_name="",
)
- base64_url = base64.urlsafe_b64encode(signed_url.encode('utf-8')).decode('utf-8')
+ base64_url = base64.urlsafe_b64encode(signed_url.encode("utf-8")).decode("utf-8")
# remove any base64 encoding padding:
- return 'k8s-aws-v1.' + base64_url.rstrip("=")
-
-
-class EKSHook(EksHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.eks.EksHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. Please use `airflow.providers.amazon.aws.hooks.eks.EksHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
+ return "k8s-aws-v1." + base64_url.rstrip("=")
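Reviewer sketch for the token format returned by `fetch_access_token_for_cluster` above: the presigned STS URL is URL-safe base64 encoded, the `=` padding is stripped, and the `k8s-aws-v1.` prefix is prepended. The helper names `encode_token` and `decode_token` are hypothetical and are not part of the hook; `decode_token` only illustrates how the stripped padding can be restored.

```python
import base64

TOKEN_PREFIX = "k8s-aws-v1."


def encode_token(signed_url: str) -> str:
    """Mirror the encoding step shown in the diff (hypothetical helper)."""
    encoded = base64.urlsafe_b64encode(signed_url.encode("utf-8")).decode("utf-8")
    # Padding characters are removed, as in the hook.
    return TOKEN_PREFIX + encoded.rstrip("=")


def decode_token(token: str) -> str:
    """Reverse the encoding; not in the hook, shown for illustration only."""
    if not token.startswith(TOKEN_PREFIX):
        raise ValueError("unexpected token prefix")
    body = token[len(TOKEN_PREFIX):]
    # Base64 input length must be a multiple of 4, so re-add the stripped padding.
    body += "=" * (-len(body) % 4)
    return base64.urlsafe_b64decode(body).decode("utf-8")
```

Round-tripping a URL through both helpers should return the original string, which is a cheap way to check that stripping the padding loses no information.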
diff --git a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
index 47af28845d1aa..f859bef5e861f 100644
--- a/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
+++ b/airflow/providers/amazon/aws/hooks/elasticache_replication_group.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from time import sleep
-from typing import Optional
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -57,7 +58,6 @@ def create_replication_group(self, config: dict) -> dict:
:param config: Configuration for creating the replication group
:return: Response from ElastiCache create replication group API
- :rtype: dict
"""
return self.conn.create_replication_group(**config)
@@ -67,7 +67,6 @@ def delete_replication_group(self, replication_group_id: str) -> dict:
:param replication_group_id: ID of replication group to delete
:return: Response from ElastiCache delete replication group API
- :rtype: dict
"""
return self.conn.delete_replication_group(ReplicationGroupId=replication_group_id)
@@ -77,7 +76,6 @@ def describe_replication_group(self, replication_group_id: str) -> dict:
:param replication_group_id: ID of replication group to describe
:return: Response from ElastiCache describe replication group API
- :rtype: dict
"""
return self.conn.describe_replication_groups(ReplicationGroupId=replication_group_id)
@@ -87,9 +85,8 @@ def get_replication_group_status(self, replication_group_id: str) -> str:
:param replication_group_id: ID of replication group to check for status
:return: Current status of replication group
- :rtype: str
"""
- return self.describe_replication_group(replication_group_id)['ReplicationGroups'][0]['Status']
+ return self.describe_replication_group(replication_group_id)["ReplicationGroups"][0]["Status"]
def is_replication_group_available(self, replication_group_id: str) -> bool:
"""
@@ -97,17 +94,16 @@ def is_replication_group_available(self, replication_group_id: str) -> bool:
:param replication_group_id: ID of replication group to check for availability
:return: True if available else False
- :rtype: bool
"""
- return self.get_replication_group_status(replication_group_id) == 'available'
+ return self.get_replication_group_status(replication_group_id) == "available"
def wait_for_availability(
self,
replication_group_id: str,
- initial_sleep_time: Optional[float] = None,
- exponential_back_off_factor: Optional[float] = None,
- max_retries: Optional[int] = None,
- ):
+ initial_sleep_time: float | None = None,
+ exponential_back_off_factor: float | None = None,
+ max_retries: int | None = None,
+ ) -> bool:
"""
Check if replication group is available or not by performing a describe over it
@@ -119,13 +115,12 @@ def wait_for_availability(
:param max_retries: Max retries for checking availability of replication group
If this is not supplied then this is defaulted to class level value
:return: True if replication is available else False
- :rtype: bool
"""
sleep_time = initial_sleep_time or self.initial_poke_interval
exponential_back_off_factor = exponential_back_off_factor or self.exponential_back_off_factor
max_retries = max_retries or self.max_retries
num_tries = 0
- status = 'not-found'
+ status = "not-found"
stop_poking = False
while not stop_poking and num_tries <= max_retries:
@@ -133,7 +128,7 @@ def wait_for_availability(
stop_poking = status in self.TERMINAL_STATES
self.log.info(
- 'Current status of replication group with ID %s is %s', replication_group_id, status
+ "Current status of replication group with ID %s is %s", replication_group_id, status
)
if not stop_poking:
@@ -143,13 +138,13 @@ def wait_for_availability(
if num_tries > max_retries:
break
- self.log.info('Poke retry %s. Sleep time %s seconds. Sleeping...', num_tries, sleep_time)
+ self.log.info("Poke retry %s. Sleep time %s seconds. Sleeping...", num_tries, sleep_time)
sleep(sleep_time)
sleep_time *= exponential_back_off_factor
- if status != 'available':
+ if status != "available":
self.log.warning('Replication group is not available. Current status is "%s"', status)
return False
@@ -159,9 +154,9 @@ def wait_for_availability(
def wait_for_deletion(
self,
replication_group_id: str,
- initial_sleep_time: Optional[float] = None,
- exponential_back_off_factor: Optional[float] = None,
- max_retries: Optional[int] = None,
+ initial_sleep_time: float | None = None,
+ exponential_back_off_factor: float | None = None,
+ max_retries: int | None = None,
):
"""
Helper for deleting a replication group ensuring it is either deleted or can't be deleted
@@ -174,7 +169,6 @@ def wait_for_deletion(
:param max_retries: Max retries for checking availability of replication group
If this is not supplied then this is defaulted to class level value
:return: Response from ElastiCache delete replication group API and flag to identify if deleted or not
- :rtype: (dict, bool)
"""
deleted = False
sleep_time = initial_sleep_time or self.initial_poke_interval
@@ -188,12 +182,12 @@ def wait_for_deletion(
status = self.get_replication_group_status(replication_group_id=replication_group_id)
self.log.info(
- 'Current status of replication group with ID %s is %s', replication_group_id, status
+ "Current status of replication group with ID %s is %s", replication_group_id, status
)
# Can only delete if status is `available`
# Status becomes `deleting` after this call so this will only run once
- if status == 'available':
+ if status == "available":
self.log.info("Initiating delete and then wait for it to finish")
response = self.delete_replication_group(replication_group_id=replication_group_id)
@@ -213,9 +207,9 @@ def wait_for_deletion(
# modifying - Replication group has status deleting which is not valid
# for deletion.
- message = exp.response['Error']['Message']
+ message = exp.response["Error"]["Message"]
- self.log.warning('Received error message from AWS ElastiCache API : %s', message)
+ self.log.warning("Received error message from AWS ElastiCache API : %s", message)
if not deleted:
num_tries += 1
@@ -224,7 +218,7 @@ def wait_for_deletion(
if num_tries > max_retries:
break
- self.log.info('Poke retry %s. Sleep time %s seconds. Sleeping...', num_tries, sleep_time)
+ self.log.info("Poke retry %s. Sleep time %s seconds. Sleeping...", num_tries, sleep_time)
sleep(sleep_time)
@@ -235,10 +229,10 @@ def wait_for_deletion(
def ensure_delete_replication_group(
self,
replication_group_id: str,
- initial_sleep_time: Optional[float] = None,
- exponential_back_off_factor: Optional[float] = None,
- max_retries: Optional[int] = None,
- ):
+ initial_sleep_time: float | None = None,
+ exponential_back_off_factor: float | None = None,
+ max_retries: int | None = None,
+ ) -> dict:
"""
Delete a replication group ensuring it is either deleted or can't be deleted
@@ -250,10 +244,9 @@ def ensure_delete_replication_group(
:param max_retries: Max retries for checking availability of replication group
If this is not supplied then this is defaulted to class level value
:return: Response from ElastiCache delete replication group API
- :rtype: dict
:raises AirflowException: If replication group is not deleted
"""
- self.log.info('Deleting replication group with ID %s', replication_group_id)
+ self.log.info("Deleting replication group with ID %s", replication_group_id)
response, deleted = self.wait_for_deletion(
replication_group_id=replication_group_id,
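Reviewer sketch of the polling pattern used by `wait_for_availability` in this file: check the status, stop on a terminal state, otherwise sleep and multiply the sleep time by a back-off factor until the retry budget is spent. The function name `wait_for_status`, the default values, the `terminal_states` set and the injectable `sleep_fn` are assumptions made for illustration; the hook takes its defaults from class-level attributes.

```python
from time import sleep
from typing import Callable, FrozenSet


def wait_for_status(
    get_status: Callable[[], str],
    target: str = "available",
    terminal_states: FrozenSet[str] = frozenset({"available", "create-failed"}),
    initial_sleep_time: float = 1.0,
    back_off_factor: float = 2.0,
    max_retries: int = 5,
    sleep_fn: Callable[[float], None] = sleep,
) -> bool:
    """Poll get_status with exponential back-off; return True if target is reached."""
    sleep_time = initial_sleep_time
    num_tries = 0
    status = "not-found"
    while num_tries <= max_retries:
        status = get_status()
        if status in terminal_states:
            break
        num_tries += 1
        if num_tries > max_retries:
            break
        sleep_fn(sleep_time)
        # Each retry waits longer than the previous one.
        sleep_time *= back_off_factor
    return status == target
```

Passing `sleep_fn` explicitly keeps the sketch testable without real delays, for example by collecting the requested sleep times in a list.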
diff --git a/airflow/providers/amazon/aws/hooks/emr.py b/airflow/providers/amazon/aws/hooks/emr.py
index 143bdcdcc8913..5423dd1af83a5 100644
--- a/airflow/providers/amazon/aws/hooks/emr.py
+++ b/airflow/providers/amazon/aws/hooks/emr.py
@@ -15,19 +15,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+import json
+import warnings
from time import sleep
-from typing import Any, Dict, List, Optional
+from typing import Any, Callable
from botocore.exceptions import ClientError
-from airflow.exceptions import AirflowException
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
class EmrHook(AwsBaseHook):
"""
- Interact with AWS EMR. emr_conn_id is only necessary for using the
- create_job_flow method.
+ Interact with Amazon Elastic MapReduce Service.
+
+ :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection `.
+ This attribute is only necessary when using
+ the :meth:`~airflow.providers.amazon.aws.hooks.emr.EmrHook.create_job_flow` method.
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -36,17 +44,17 @@ class EmrHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- conn_name_attr = 'emr_conn_id'
- default_conn_name = 'emr_default'
- conn_type = 'emr'
- hook_name = 'Amazon Elastic MapReduce'
+ conn_name_attr = "emr_conn_id"
+ default_conn_name = "emr_default"
+ conn_type = "emr"
+ hook_name = "Amazon Elastic MapReduce"
- def __init__(self, emr_conn_id: Optional[str] = default_conn_name, *args, **kwargs) -> None:
+ def __init__(self, emr_conn_id: str | None = default_conn_name, *args, **kwargs) -> None:
self.emr_conn_id = emr_conn_id
kwargs["client_type"] = "emr"
super().__init__(*args, **kwargs)
- def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: List[str]) -> Optional[str]:
+ def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: list[str]) -> str | None:
"""
Fetch id of EMR cluster with given name and (optional) states.
Will return only if single id is found.
@@ -58,38 +66,222 @@ def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: List[str
response = self.get_conn().list_clusters(ClusterStates=cluster_states)
matching_clusters = list(
- filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
+ filter(lambda cluster: cluster["Name"] == emr_cluster_name, response["Clusters"])
)
if len(matching_clusters) == 1:
- cluster_id = matching_clusters[0]['Id']
- self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
+ cluster_id = matching_clusters[0]["Id"]
+ self.log.info("Found cluster name = %s id = %s", emr_cluster_name, cluster_id)
return cluster_id
elif len(matching_clusters) > 1:
- raise AirflowException(f'More than one cluster found for name {emr_cluster_name}')
+ raise AirflowException(f"More than one cluster found for name {emr_cluster_name}")
else:
- self.log.info('No cluster found for name %s', emr_cluster_name)
+ self.log.info("No cluster found for name %s", emr_cluster_name)
return None
- def create_job_flow(self, job_flow_overrides: Dict[str, Any]) -> Dict[str, Any]:
- """
- Creates a job flow using the config from the EMR connection.
- Keys of the json extra hash may have the arguments of the boto3
- run_job_flow method.
- Overrides for this config may be passed as the job_flow_overrides.
+ def create_job_flow(self, job_flow_overrides: dict[str, Any]) -> dict[str, Any]:
"""
- if not self.emr_conn_id:
- raise AirflowException('emr_conn_id must be present to use create_job_flow')
+ Create and start running a new cluster (job flow).
- emr_conn = self.get_connection(self.emr_conn_id)
+ This method uses ``EmrHook.emr_conn_id`` to receive the initial Amazon EMR cluster configuration.
+ If ``EmrHook.emr_conn_id`` is empty or the connection does not exist, then an empty initial
+ configuration is used.
- config = emr_conn.extra_dejson.copy()
+ :param job_flow_overrides: Overrides the parameters in the initial Amazon EMR cluster
+ configuration. The resulting configuration is passed to the boto3 EMR client ``run_job_flow`` method.
+
+ .. seealso::
+ - :ref:`Amazon Elastic MapReduce Connection `
+ - `API RunJobFlow `_
+ - `boto3 emr client run_job_flow method `_.
+ """
+ config = {}
+ if self.emr_conn_id:
+ try:
+ emr_conn = self.get_connection(self.emr_conn_id)
+ except AirflowNotFoundException:
+ warnings.warn(
+ f"Unable to find {self.hook_name} Connection ID {self.emr_conn_id!r}, "
+ "using an empty initial configuration. If you want to get rid of this warning "
+ "message please provide a valid `emr_conn_id` or set it to None.",
+ UserWarning,
+ stacklevel=2,
+ )
+ else:
+ if emr_conn.conn_type and emr_conn.conn_type != self.conn_type:
+ warnings.warn(
+ f"{self.hook_name} Connection expected connection type {self.conn_type!r}, "
+ f"Connection {self.emr_conn_id!r} has conn_type={emr_conn.conn_type!r}. "
+ "This connection might not work correctly.",
+ UserWarning,
+ stacklevel=2,
+ )
+ config = emr_conn.extra_dejson.copy()
config.update(job_flow_overrides)
response = self.get_conn().run_job_flow(**config)
return response
+ def add_job_flow_steps(
+ self, job_flow_id: str, steps: list[dict] | str | None = None, wait_for_completion: bool = False
+ ) -> list[str]:
+ """
+ Add new steps to a running cluster.
+
+ :param job_flow_id: The id of the job flow to which the steps are being added
+ :param steps: A list of the steps to be executed by the job flow
+ :param wait_for_completion: If True, wait for the steps to be completed. Default is False
+ """
+ response = self.get_conn().add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps)
+
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Adding steps failed: {response}")
+
+ self.log.info("Steps %s added to JobFlow", response["StepIds"])
+ if wait_for_completion:
+ waiter = self.get_conn().get_waiter("step_complete")
+ for step_id in response["StepIds"]:
+ waiter.wait(
+ ClusterId=job_flow_id,
+ StepId=step_id,
+ WaiterConfig={
+ "Delay": 5,
+ "MaxAttempts": 100,
+ },
+ )
+ return response["StepIds"]
+
+ def test_connection(self):
+ """
+ Return a failed state when testing the Amazon Elastic MapReduce Connection (it is untestable).
+
+ We need to override this method because this hook is based on
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`,
+ otherwise it will try to test connection to AWS STS by using the default boto3 credential strategy.
+ """
+ msg = (
+ f"{self.hook_name!r} Airflow Connection cannot be tested, by design it stores "
+ "only key/value pairs and does not make a connection to an external resource."
+ )
+ return False, msg
+
+ @staticmethod
+ def get_ui_field_behaviour() -> dict[str, Any]:
+ """Returns custom UI field behaviour for Amazon Elastic MapReduce Connection."""
+ return {
+ "hidden_fields": ["host", "schema", "port", "login", "password"],
+ "relabeling": {
+ "extra": "Run Job Flow Configuration",
+ },
+ "placeholders": {
+ "extra": json.dumps(
+ {
+ "Name": "MyClusterName",
+ "ReleaseLabel": "emr-5.36.0",
+ "Applications": [{"Name": "Spark"}],
+ "Instances": {
+ "InstanceGroups": [
+ {
+ "Name": "Primary node",
+ "Market": "SPOT",
+ "InstanceRole": "MASTER",
+ "InstanceType": "m5.large",
+ "InstanceCount": 1,
+ },
+ ],
+ "KeepJobFlowAliveWhenNoSteps": False,
+ "TerminationProtected": False,
+ },
+ "StepConcurrencyLevel": 2,
+ },
+ indent=2,
+ ),
+ },
+ }
+
+
+class EmrServerlessHook(AwsBaseHook):
+ """
+ Interact with EMR Serverless API.
+
+ Additional arguments (such as ``aws_conn_id``) may be specified and
+ are passed down to the underlying AwsBaseHook.
+
+ .. seealso::
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+ """
+
+ JOB_INTERMEDIATE_STATES = {"PENDING", "RUNNING", "SCHEDULED", "SUBMITTED"}
+ JOB_FAILURE_STATES = {"FAILED", "CANCELLING", "CANCELLED"}
+ JOB_SUCCESS_STATES = {"SUCCESS"}
+ JOB_TERMINAL_STATES = JOB_SUCCESS_STATES.union(JOB_FAILURE_STATES)
+
+ APPLICATION_INTERMEDIATE_STATES = {"CREATING", "STARTING", "STOPPING"}
+ APPLICATION_FAILURE_STATES = {"STOPPED", "TERMINATED"}
+ APPLICATION_SUCCESS_STATES = {"CREATED", "STARTED"}
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ kwargs["client_type"] = "emr-serverless"
+ super().__init__(*args, **kwargs)
+
+ @cached_property
+ def conn(self):
+ """Get the underlying boto3 EmrServerlessAPIService client (cached)"""
+ return super().conn
+
+ # This method should be replaced with boto waiters which would implement timeouts and backoff nicely.
+ def waiter(
+ self,
+ get_state_callable: Callable,
+ get_state_args: dict,
+ parse_response: list,
+ desired_state: set,
+ failure_states: set,
+ object_type: str,
+ action: str,
+ countdown: int = 25 * 60,
+ check_interval_seconds: int = 60,
+ ) -> None:
+ """
+ Poll ``get_state_callable`` until the extracted state is one of ``desired_state``.
+
+ :param get_state_callable: A callable whose response contains the state to check
+ :param get_state_args: Arguments to pass to get_state_callable
+ :param parse_response: List of nested dictionary keys used to extract the state from the response of get_state_callable
+ :param desired_state: Set of states to wait for; polling stops once the current state is in this set
+ :param failure_states: A set of states which indicate failure and should throw an
+ exception if any are reached before the desired_state
+ :param object_type: Used for the reporting string. What are you waiting for? (application, job, etc)
+ :param action: Used for the reporting string. What action are you waiting for? (created, deleted, etc)
+ :param countdown: Total amount of time the waiter should wait for the desired state
+ before timing out (in seconds). Defaults to 25 * 60 seconds.
+ :param check_interval_seconds: Number of seconds waiter should wait before attempting
+ to retry get_state_callable. Defaults to 60 seconds.
+ """
+ response = get_state_callable(**get_state_args)
+ state: str = self.get_state(response, parse_response)
+ while state not in desired_state:
+ if state in failure_states:
+ raise AirflowException(f"{object_type.title()} reached failure state {state}.")
+ if countdown >= check_interval_seconds:
+ countdown -= check_interval_seconds
+ self.log.info("Waiting for %s to be %s.", object_type.lower(), action.lower())
+ sleep(check_interval_seconds)
+ state = self.get_state(get_state_callable(**get_state_args), parse_response)
+ else:
+ message = f"{object_type.title()} still not {action.lower()} after the allocated time limit."
+ self.log.error(message)
+ raise RuntimeError(message)
+
+ def get_state(self, response, keys) -> str:
+ value = response
+ for key in keys:
+ if value is not None:
+ value = value.get(key, None)
+ return value
+
class EmrContainerHook(AwsBaseHook):
"""
@@ -121,19 +313,45 @@ class EmrContainerHook(AwsBaseHook):
"CANCEL_PENDING",
)
- def __init__(self, *args: Any, virtual_cluster_id: Optional[str] = None, **kwargs: Any) -> None:
+ def __init__(self, *args: Any, virtual_cluster_id: str | None = None, **kwargs: Any) -> None:
super().__init__(client_type="emr-containers", *args, **kwargs) # type: ignore
self.virtual_cluster_id = virtual_cluster_id
+ def create_emr_on_eks_cluster(
+ self,
+ virtual_cluster_name: str,
+ eks_cluster_name: str,
+ eks_namespace: str,
+ tags: dict | None = None,
+ ) -> str:
+ response = self.conn.create_virtual_cluster(
+ name=virtual_cluster_name,
+ containerProvider={
+ "id": eks_cluster_name,
+ "type": "EKS",
+ "info": {"eksInfo": {"namespace": eks_namespace}},
+ },
+ tags=tags or {},
+ )
+
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Create EMR EKS Cluster failed: {response}")
+ else:
+ self.log.info(
+ "Create EMR EKS Cluster success - virtual cluster id %s",
+ response["id"],
+ )
+ return response["id"]
+
def submit_job(
self,
name: str,
execution_role_arn: str,
release_label: str,
job_driver: dict,
- configuration_overrides: Optional[dict] = None,
- client_request_token: Optional[str] = None,
- tags: Optional[dict] = None,
+ configuration_overrides: dict | None = None,
+ client_request_token: str | None = None,
+ tags: dict | None = None,
) -> str:
"""
Submit a job to the EMR Containers API and return the job ID.
@@ -166,17 +384,17 @@ def submit_job(
response = self.conn.start_job_run(**params)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Start Job Run failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Start Job Run failed: {response}")
else:
self.log.info(
"Start Job Run success - Job Id %s and virtual cluster id %s",
- response['id'],
- response['virtualClusterId'],
+ response["id"],
+ response["virtualClusterId"],
)
- return response['id']
+ return response["id"]
- def get_job_failure_reason(self, job_id: str) -> Optional[str]:
+ def get_job_failure_reason(self, job_id: str) -> str | None:
"""
Fetch the reason for a job failure (e.g. error message). Returns None or reason string.
@@ -191,17 +409,17 @@ def get_job_failure_reason(self, job_id: str) -> Optional[str]:
virtualClusterId=self.virtual_cluster_id,
id=job_id,
)
- failure_reason = response['jobRun']['failureReason']
+ failure_reason = response["jobRun"]["failureReason"]
state_details = response["jobRun"]["stateDetails"]
reason = f"{failure_reason} - {state_details}"
except KeyError:
- self.log.error('Could not get status of the EMR on EKS job')
+ self.log.error("Could not get status of the EMR on EKS job")
except ClientError as ex:
- self.log.error('AWS request failed, check logs for more info: %s', ex)
+ self.log.error("AWS request failed, check logs for more info: %s", ex)
return reason
- def check_query_status(self, job_id: str) -> Optional[str]:
+ def check_query_status(self, job_id: str) -> str | None:
"""
Fetch the status of submitted job run. Returns None or one of valid query states.
See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-containers.html#EMRContainers.Client.describe_job_run # noqa: E501
@@ -216,26 +434,43 @@ def check_query_status(self, job_id: str) -> Optional[str]:
return response["jobRun"]["state"]
except self.conn.exceptions.ResourceNotFoundException:
# If the job is not found, we raise an exception as something fatal has happened.
- raise AirflowException(f'Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}')
+ raise AirflowException(f"Job ID {job_id} not found on Virtual Cluster {self.virtual_cluster_id}")
except ClientError as ex:
# If we receive a generic ClientError, we swallow the exception, log it, and return None.
- self.log.error('AWS request failed, check logs for more info: %s', ex)
+ self.log.error("AWS request failed, check logs for more info: %s", ex)
return None
def poll_query_status(
- self, job_id: str, max_tries: Optional[int] = None, poll_interval: int = 30
- ) -> Optional[str]:
+ self,
+ job_id: str,
+ max_tries: int | None = None,
+ poll_interval: int = 30,
+ max_polling_attempts: int | None = None,
+ ) -> str | None:
"""
Poll the status of submitted job run until query state reaches final state.
Returns one of the final states.
:param job_id: Id of submitted job run
- :param max_tries: Number of times to poll for query state before function exits
+ :param max_tries: Deprecated - Use max_polling_attempts instead
:param poll_interval: Time (in seconds) to wait between calls to check query status on EMR
+ :param max_polling_attempts: Number of times to poll for query state before function exits
:return: str
"""
+ if max_tries:
+ warnings.warn(
+ f"Parameter `max_tries` of `{self.__class__.__name__}.poll_query_status` is deprecated and will be "
+ "removed in a future release. Please use parameter `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ max_polling_attempts = max_tries
+
try_number = 1
- final_query_state = None # Query state when query reaches final state or max_tries reached
+ final_query_state = None # Query state when query reaches final state or max_polling_attempts reached
while True:
query_state = self.check_query_status(job_id)
@@ -247,14 +482,16 @@ def poll_query_status(
break
else:
self.log.info("Try %s: Query is still in non-terminal state - %s", try_number, query_state)
- if max_tries and try_number >= max_tries: # Break loop if max_tries reached
+ if (
+ max_polling_attempts and try_number >= max_polling_attempts
+ ): # Break loop if max_polling_attempts reached
final_query_state = query_state
break
try_number += 1
sleep(poll_interval)
return final_query_state
- def stop_query(self, job_id: str) -> Dict:
+ def stop_query(self, job_id: str) -> dict:
"""
Cancel the submitted job_run
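Reviewer sketch of the `EmrServerlessHook.waiter` and `get_state` logic added in this file: walk a list of nested keys to extract a state, then poll until the state is in the desired set, raising on a failure state or when the time budget runs out. The names `wait_for_state` and `sleep_fn`, the shortened defaults, and the use of `TimeoutError` are assumptions for illustration; the hook raises `AirflowException` and `RuntimeError` and sleeps for real.

```python
from time import sleep
from typing import Any, Callable, Optional, Sequence, Set


def get_state(response: Optional[dict], keys: Sequence[str]) -> Any:
    """Follow nested keys; yield None as soon as a level is missing."""
    value = response
    for key in keys:
        if value is not None:
            value = value.get(key)
    return value


def wait_for_state(
    get_response: Callable[[], dict],
    keys: Sequence[str],
    desired_states: Set[str],
    failure_states: Set[str],
    countdown: int = 60,
    check_interval_seconds: int = 20,
    sleep_fn: Callable[[float], None] = sleep,
) -> str:
    """Poll until the extracted state is desired (hypothetical helper)."""
    state = get_state(get_response(), keys)
    while state not in desired_states:
        if state in failure_states:
            raise RuntimeError(f"Reached failure state {state}.")
        if countdown < check_interval_seconds:
            raise TimeoutError("Desired state not reached within the time limit.")
        # Spend one interval of the remaining budget, then check again.
        countdown -= check_interval_seconds
        sleep_fn(check_interval_seconds)
        state = get_state(get_response(), keys)
    return state
```

Note that `get_state` can yield `None` when a key is missing, so a caller that relies on the `-> str` annotation in the diff may want to handle that case.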
diff --git a/airflow/providers/amazon/aws/hooks/emr_containers.py b/airflow/providers/amazon/aws/hooks/emr_containers.py
deleted file mode 100644
index 1e3b7a0ea4aea..0000000000000
--- a/airflow/providers/amazon/aws/hooks/emr_containers.py
+++ /dev/null
@@ -1,44 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.hooks.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class EMRContainerHook(EmrContainerHook):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.emr.EmrContainerHook`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.amazon.aws.hooks.emr.EmrContainerHook`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
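Reviewer sketch of the deprecated-alias pattern that the deleted module above implemented: a subclass that emits a `DeprecationWarning` and then defers to the replacement class. The class names ending in `Sketch` and the helper `instantiate_and_capture` are hypothetical stand-ins so the example runs without Airflow installed; they are not the provider's classes.

```python
import warnings


class EmrContainerHookSketch:
    """Stand-in for the replacement class."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class EMRContainerHookSketch(EmrContainerHookSketch):
    """Deprecated alias that warns and then delegates to the parent."""

    def __init__(self, **kwargs):
        warnings.warn(
            "This class is deprecated. Please use `EmrContainerHookSketch`.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this line
        )
        super().__init__(**kwargs)


def instantiate_and_capture(cls, **kwargs):
    """Create an instance and return it with the warning categories raised."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        instance = cls(**kwargs)
    return instance, [w.category for w in caught]
```

Code that still imports the removed alias will fail at import time after this change, so a check like `instantiate_and_capture` is only useful before the shim is deleted.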
diff --git a/airflow/providers/amazon/aws/hooks/glacier.py b/airflow/providers/amazon/aws/hooks/glacier.py
index 00c4b884ae262..926dc05843509 100644
--- a/airflow/providers/amazon/aws/hooks/glacier.py
+++ b/airflow/providers/amazon/aws/hooks/glacier.py
@@ -15,9 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-
-from typing import Any, Dict
+from typing import Any
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,20 +29,20 @@ def __init__(self, aws_conn_id: str = "aws_default") -> None:
super().__init__(client_type="glacier")
self.aws_conn_id = aws_conn_id
- def retrieve_inventory(self, vault_name: str) -> Dict[str, Any]:
+ def retrieve_inventory(self, vault_name: str) -> dict[str, Any]:
"""
Initiate an Amazon Glacier inventory-retrieval job
:param vault_name: the Glacier vault on which job is executed
"""
- job_params = {'Type': 'inventory-retrieval'}
+ job_params = {"Type": "inventory-retrieval"}
self.log.info("Retrieving inventory for vault: %s", vault_name)
response = self.get_conn().initiate_job(vaultName=vault_name, jobParameters=job_params)
self.log.info("Initiated inventory-retrieval job for: %s", vault_name)
self.log.info("Retrieval Job ID: %s", response["jobId"])
return response
- def retrieve_inventory_results(self, vault_name: str, job_id: str) -> Dict[str, Any]:
+ def retrieve_inventory_results(self, vault_name: str, job_id: str) -> dict[str, Any]:
"""
Retrieve the results of an Amazon Glacier inventory-retrieval job
@@ -53,7 +53,7 @@ def retrieve_inventory_results(self, vault_name: str, job_id: str) -> Dict[str,
response = self.get_conn().get_job_output(vaultName=vault_name, jobId=job_id)
return response
- def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]:
+ def describe_job(self, vault_name: str, job_id: str) -> dict[str, Any]:
"""
Retrieve the status of an Amazon S3 Glacier job, such as an
inventory-retrieval job
@@ -63,5 +63,5 @@ def describe_job(self, vault_name: str, job_id: str) -> Dict[str, Any]:
"""
self.log.info("Retrieving status for vault: %s and job %s", vault_name, job_id)
response = self.get_conn().describe_job(vaultName=vault_name, jobId=job_id)
- self.log.info("Job status: %s, code status: %s", response['Action'], response['StatusCode'])
+ self.log.info("Job status: %s, code status: %s", response["Action"], response["StatusCode"])
return response
diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py
index dcd6d7c4661cd..f318417412858 100644
--- a/airflow/providers/amazon/aws/hooks/glue.py
+++ b/airflow/providers/amazon/aws/hooks/glue.py
@@ -15,14 +15,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import time
-import warnings
-from typing import Dict, List, Optional
+
+import boto3
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+DEFAULT_LOG_SUFFIX = "output"
+FAILURE_LOG_SUFFIX = "error"
+# A filter value of ' ' translates to "match all".
+# see: https://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html
+DEFAULT_LOG_FILTER = " "
+FAILURE_LOG_FILTER = "?ERROR ?Exception"
+
class GlueJobHook(AwsBaseHook):
"""
@@ -44,15 +52,15 @@ class GlueJobHook(AwsBaseHook):
def __init__(
self,
- s3_bucket: Optional[str] = None,
- job_name: Optional[str] = None,
- desc: Optional[str] = None,
+ s3_bucket: str | None = None,
+ job_name: str | None = None,
+ desc: str | None = None,
concurrent_run_limit: int = 1,
- script_location: Optional[str] = None,
+ script_location: str | None = None,
retry_limit: int = 0,
- num_of_dpus: Optional[int] = None,
- iam_role_name: Optional[str] = None,
- create_job_kwargs: Optional[dict] = None,
+ num_of_dpus: int | None = None,
+ iam_role_name: str | None = None,
+ create_job_kwargs: dict | None = None,
*args,
**kwargs,
):
@@ -63,7 +71,7 @@ def __init__(
self.retry_limit = retry_limit
self.s3_bucket = s3_bucket
self.role_name = iam_role_name
- self.s3_glue_logs = 'logs/glue-logs/'
+ self.s3_glue_logs = "logs/glue-logs/"
self.create_job_kwargs = create_job_kwargs or {}
worker_type_exists = "WorkerType" in self.create_job_kwargs
@@ -81,20 +89,20 @@ def __init__(
else:
self.num_of_dpus = num_of_dpus
- kwargs['client_type'] = 'glue'
+ kwargs["client_type"] = "glue"
super().__init__(*args, **kwargs)
- def list_jobs(self) -> List:
+ def list_jobs(self) -> list:
""":return: Lists of Jobs"""
conn = self.get_conn()
return conn.get_jobs()
- def get_iam_execution_role(self) -> Dict:
+ def get_iam_execution_role(self) -> dict:
""":return: iam role for job execution"""
- session, endpoint_url = self._get_credentials(region_name=self.region_name)
- iam_client = session.client('iam', endpoint_url=endpoint_url, config=self.config, verify=self.verify)
-
try:
+ iam_client = self.get_session(region_name=self.region_name).client(
+ "iam", endpoint_url=self.conn_config.endpoint_url, config=self.config, verify=self.verify
+ )
glue_execution_role = iam_client.get_role(RoleName=self.role_name)
self.log.info("Iam Role Name: %s", self.role_name)
return glue_execution_role
@@ -104,9 +112,9 @@ def get_iam_execution_role(self) -> Dict:
def initialize_job(
self,
- script_arguments: Optional[dict] = None,
- run_kwargs: Optional[dict] = None,
- ) -> Dict[str, str]:
+ script_arguments: dict | None = None,
+ run_kwargs: dict | None = None,
+ ) -> dict[str, str]:
"""
Initializes connection with AWS Glue
to run job
@@ -134,34 +142,94 @@ def get_job_state(self, job_name: str, run_id: str) -> str:
"""
glue_client = self.get_conn()
job_run = glue_client.get_job_run(JobName=job_name, RunId=run_id, PredecessorsIncluded=True)
- return job_run['JobRun']['JobRunState']
+ return job_run["JobRun"]["JobRunState"]
+
+ def print_job_logs(
+ self,
+ job_name: str,
+ run_id: str,
+ job_failed: bool = False,
+ next_token: str | None = None,
+ ) -> str | None:
+ """Prints the batch of logs to the Airflow task log and returns nextToken."""
+ log_client = boto3.client("logs")
+ response = {}
+
+ filter_pattern = FAILURE_LOG_FILTER if job_failed else DEFAULT_LOG_FILTER
+ log_group_prefix = self.conn.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"]
+ log_group_suffix = FAILURE_LOG_SUFFIX if job_failed else DEFAULT_LOG_SUFFIX
+ log_group_name = f"{log_group_prefix}/{log_group_suffix}"
+
+ try:
+ if next_token:
+ response = log_client.filter_log_events(
+ logGroupName=log_group_name,
+ logStreamNames=[run_id],
+ filterPattern=filter_pattern,
+ nextToken=next_token,
+ )
+ else:
+ response = log_client.filter_log_events(
+ logGroupName=log_group_name,
+ logStreamNames=[run_id],
+ filterPattern=filter_pattern,
+ )
+ if len(response["events"]):
+ messages = "\t".join([event["message"] for event in response["events"]])
+ self.log.info("Glue Job Run Logs:\n\t%s", messages)
+
+ except log_client.exceptions.ResourceNotFoundException:
+ self.log.warning(
+ "No new Glue driver logs found. This might be because there are no new logs, "
+ "or might be an error.\nIf the error persists, check the CloudWatch dashboard "
+ f"at: https://{self.conn_region_name}.console.aws.amazon.com/cloudwatch/home"
+ )
+
+ # If no new log events are available (or the log group was not found),
+ # the response carries no nextToken. In that case, reuse the current token on the next pass.
+ return response.get("nextToken") or next_token
- def job_completion(self, job_name: str, run_id: str) -> Dict[str, str]:
+ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
"""
Waits until Glue job with job_name completes or
fails and return final state if finished.
Raises AirflowException when the job failed
:param job_name: unique job name per AWS account
:param run_id: The job-run ID of the predecessor job run
+ :param verbose: If True, Glue Job Run logs are also printed to the Airflow task log. (default: False)
:return: Dict of JobRunState and JobRunId
"""
- failed_states = ['FAILED', 'TIMEOUT']
- finished_states = ['SUCCEEDED', 'STOPPED']
+ failed_states = ["FAILED", "TIMEOUT"]
+ finished_states = ["SUCCEEDED", "STOPPED"]
+ next_log_token = None
+ job_failed = False
while True:
- job_run_state = self.get_job_state(job_name, run_id)
- if job_run_state in finished_states:
- self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
- return {'JobRunState': job_run_state, 'JobRunId': run_id}
- if job_run_state in failed_states:
- job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
- self.log.info(job_error_message)
- raise AirflowException(job_error_message)
- else:
- self.log.info(
- "Polling for AWS Glue Job %s current run state with status %s", job_name, job_run_state
- )
- time.sleep(self.JOB_POLL_INTERVAL)
+ try:
+ job_run_state = self.get_job_state(job_name, run_id)
+ if job_run_state in finished_states:
+ self.log.info("Exiting Job %s Run State: %s", run_id, job_run_state)
+ return {"JobRunState": job_run_state, "JobRunId": run_id}
+ if job_run_state in failed_states:
+ job_failed = True
+ job_error_message = f"Exiting Job {run_id} Run State: {job_run_state}"
+ self.log.info(job_error_message)
+ raise AirflowException(job_error_message)
+ else:
+ self.log.info(
+ "Polling for AWS Glue Job %s current run state with status %s",
+ job_name,
+ job_run_state,
+ )
+ time.sleep(self.JOB_POLL_INTERVAL)
+ finally:
+ if verbose:
+ next_log_token = self.print_job_logs(
+ job_name=job_name,
+ run_id=run_id,
+ job_failed=job_failed,
+ next_token=next_log_token,
+ )
def get_or_create_glue_job(self) -> str:
"""
@@ -172,23 +240,29 @@ def get_or_create_glue_job(self) -> str:
try:
get_job_response = glue_client.get_job(JobName=self.job_name)
self.log.info("Job Already exist. Returning Name of the job")
- return get_job_response['Job']['Name']
+ return get_job_response["Job"]["Name"]
except glue_client.exceptions.EntityNotFoundException:
self.log.info("Job doesn't exist. Now creating and running AWS Glue Job")
if self.s3_bucket is None:
- raise AirflowException('Could not initialize glue job, error: Specify Parameter `s3_bucket`')
- s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}'
+ raise AirflowException("Could not initialize glue job, error: Specify Parameter `s3_bucket`")
+ s3_log_path = f"s3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}"
execution_role = self.get_iam_execution_role()
try:
+ default_command = {
+ "Name": "glueetl",
+ "ScriptLocation": self.script_location,
+ }
+ command = self.create_job_kwargs.pop("Command", default_command)
+
if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
- Role=execution_role['Role']['Arn'],
+ Role=execution_role["Role"]["Arn"],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
- Command={"Name": "glueetl", "ScriptLocation": self.script_location},
+ Command=command,
MaxRetries=self.retry_limit,
**self.create_job_kwargs,
)
@@ -197,30 +271,14 @@ def get_or_create_glue_job(self) -> str:
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
- Role=execution_role['Role']['Arn'],
+ Role=execution_role["Role"]["Arn"],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
- Command={"Name": "glueetl", "ScriptLocation": self.script_location},
+ Command=command,
MaxRetries=self.retry_limit,
MaxCapacity=self.num_of_dpus,
**self.create_job_kwargs,
)
- return create_job_response['Name']
+ return create_job_response["Name"]
except Exception as general_error:
self.log.error("Failed to create aws glue job, error: %s", general_error)
raise
-
-
-class AwsGlueJobHook(GlueJobHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.glue.GlueJobHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.glue.GlueJobHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
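Review note on ``print_job_logs``: the method returns ``response.get("nextToken") or next_token``, so the caller keeps polling with the previous token whenever a pass yields no new one. The sketch below illustrates that carry-over in isolation. ``FakeLogsClient`` and ``fetch_batch`` are hypothetical names introduced only for this illustration (they are not part of boto3 or Airflow), and the canned pages are made-up data. Building the keyword arguments conditionally is one possible way to avoid the duplicated ``filter_log_events`` call; it is not what the diff does.

```python
from __future__ import annotations


class FakeLogsClient:
    """Hypothetical stand-in for a CloudWatch Logs client (not a boto3 class)."""

    def __init__(self, pages: dict[str | None, dict]):
        self._pages = pages

    def filter_log_events(self, **kwargs) -> dict:
        # Return the canned response for the given token (None means first call).
        return self._pages.get(kwargs.get("nextToken"), {"events": []})


def fetch_batch(client, next_token: str | None = None) -> tuple[list[str], str | None]:
    """Return one batch of messages and the token to use on the next pass."""
    kwargs: dict = {"logGroupName": "group", "logStreamNames": ["run"]}
    if next_token:
        # Only pass nextToken when we actually have one.
        kwargs["nextToken"] = next_token
    response = client.filter_log_events(**kwargs)
    messages = [event["message"] for event in response.get("events", [])]
    # Keep the previous token when the response carries no new one.
    return messages, response.get("nextToken") or next_token


client = FakeLogsClient(
    {
        None: {"events": [{"message": "start"}], "nextToken": "t1"},
        "t1": {"events": [{"message": "done"}]},
    }
)
first, token = fetch_batch(client)
second, token_after = fetch_batch(client, token)
```

In this run the second response has no ``nextToken``, so ``token_after`` stays equal to the token from the first pass, which mirrors the behaviour of the ``finally`` block in ``job_completion``.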
diff --git a/airflow/providers/amazon/aws/hooks/glue_catalog.py b/airflow/providers/amazon/aws/hooks/glue_catalog.py
index e77916d09e3b9..5387e19dbf893 100644
--- a/airflow/providers/amazon/aws/hooks/glue_catalog.py
+++ b/airflow/providers/amazon/aws/hooks/glue_catalog.py
@@ -15,10 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains AWS Glue Catalog Hook"""
-import warnings
-from typing import Dict, List, Optional, Set
+from __future__ import annotations
from botocore.exceptions import ClientError
@@ -38,16 +36,16 @@ class GlueCatalogHook(AwsBaseHook):
"""
def __init__(self, *args, **kwargs):
- super().__init__(client_type='glue', *args, **kwargs)
+ super().__init__(client_type="glue", *args, **kwargs)
def get_partitions(
self,
database_name: str,
table_name: str,
- expression: str = '',
- page_size: Optional[int] = None,
- max_items: Optional[int] = None,
- ) -> Set[tuple]:
+ expression: str = "",
+ page_size: int | None = None,
+ max_items: int | None = None,
+ ) -> set[tuple]:
"""
Retrieves the partition values for a table.
@@ -63,19 +61,19 @@ def get_partitions(
``{('2018-01-01','1'), ('2018-01-01','2')}``
"""
config = {
- 'PageSize': page_size,
- 'MaxItems': max_items,
+ "PageSize": page_size,
+ "MaxItems": max_items,
}
- paginator = self.get_conn().get_paginator('get_partitions')
+ paginator = self.get_conn().get_paginator("get_partitions")
response = paginator.paginate(
DatabaseName=database_name, TableName=table_name, Expression=expression, PaginationConfig=config
)
partitions = set()
for page in response:
- for partition in page['Partitions']:
- partitions.add(tuple(partition['Values']))
+ for partition in page["Partitions"]:
+ partitions.add(tuple(partition["Values"]))
return partitions
@@ -87,7 +85,6 @@ def check_for_partition(self, database_name: str, table_name: str, expression: s
:param table_name: Name of hive table @partition belongs to
:expression: Expression that matches the partitions to check for
(eg `a = 'b' AND c = 'd'`)
- :rtype: bool
>>> hook = GlueCatalogHook()
>>> t = 'static_babynames_partitioned'
@@ -104,7 +101,6 @@ def get_table(self, database_name: str, table_name: str) -> dict:
:param database_name: Name of hive database (schema) @table belongs to
:param table_name: Name of hive table
- :rtype: dict
>>> hook = GlueCatalogHook()
>>> r = hook.get_table('db', 'table_foo')
@@ -112,7 +108,7 @@ def get_table(self, database_name: str, table_name: str) -> dict:
"""
result = self.get_conn().get_table(DatabaseName=database_name, Name=table_name)
- return result['Table']
+ return result["Table"]
def get_table_location(self, database_name: str, table_name: str) -> str:
"""
@@ -124,9 +120,9 @@ def get_table_location(self, database_name: str, table_name: str) -> str:
"""
table = self.get_table(database_name, table_name)
- return table['StorageDescriptor']['Location']
+ return table["StorageDescriptor"]["Location"]
- def get_partition(self, database_name: str, table_name: str, partition_values: List[str]) -> Dict:
+ def get_partition(self, database_name: str, table_name: str, partition_values: list[str]) -> dict:
"""
Gets a Partition
@@ -136,7 +132,6 @@ def get_partition(self, database_name: str, table_name: str, partition_values: L
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-GetPartition
- :rtype: dict
:raises: AirflowException
@@ -153,7 +148,7 @@ def get_partition(self, database_name: str, table_name: str, partition_values: L
self.log.error("Client error: %s", e)
raise AirflowException("AWS request failed, check logs for more info")
- def create_partition(self, database_name: str, table_name: str, partition_input: Dict) -> Dict:
+ def create_partition(self, database_name: str, table_name: str, partition_input: dict) -> dict:
"""
Creates a new Partition
@@ -163,7 +158,6 @@ def create_partition(self, database_name: str, table_name: str, partition_input:
Please see official AWS documentation for further information.
https://docs.aws.amazon.com/glue/latest/dg/aws-glue-api-catalog-partitions.html#aws-glue-api-catalog-partitions-CreatePartition
- :rtype: dict
:raises: AirflowException
@@ -178,19 +172,3 @@ def create_partition(self, database_name: str, table_name: str, partition_input:
except ClientError as e:
self.log.error("Client error: %s", e)
raise AirflowException("AWS request failed, check logs for more info")
-
-
-class AwsGlueCatalogHook(GlueCatalogHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.glue_catalog.GlueCatalogHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/glue_crawler.py b/airflow/providers/amazon/aws/hooks/glue_crawler.py
index 65f7df8d28566..917b96b2c6549 100644
--- a/airflow/providers/amazon/aws/hooks/glue_crawler.py
+++ b/airflow/providers/amazon/aws/hooks/glue_crawler.py
@@ -15,15 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
-import warnings
-from time import sleep
+from __future__ import annotations
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from time import sleep
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -40,7 +36,7 @@ class GlueCrawlerHook(AwsBaseHook):
"""
def __init__(self, *args, **kwargs):
- kwargs['client_type'] = 'glue'
+ kwargs["client_type"] = "glue"
super().__init__(*args, **kwargs)
@cached_property
@@ -70,7 +66,7 @@ def get_crawler(self, crawler_name: str) -> dict:
:param crawler_name: unique crawler name per AWS account
:return: Nested dictionary of crawler configurations
"""
- return self.glue_client.get_crawler(Name=crawler_name)['Crawler']
+ return self.glue_client.get_crawler(Name=crawler_name)["Crawler"]
def update_crawler(self, **crawler_kwargs) -> bool:
"""
@@ -79,7 +75,7 @@ def update_crawler(self, **crawler_kwargs) -> bool:
:param crawler_kwargs: Keyword args that define the configurations used for the crawler
:return: True if crawler was updated and false otherwise
"""
- crawler_name = crawler_kwargs['Name']
+ crawler_name = crawler_kwargs["Name"]
current_crawler = self.get_crawler(crawler_name)
update_config = {
@@ -100,7 +96,7 @@ def create_crawler(self, **crawler_kwargs) -> str:
:param crawler_kwargs: Keyword args that define the configurations used to create the crawler
:return: Name of the crawler
"""
- crawler_name = crawler_kwargs['Name']
+ crawler_name = crawler_kwargs["Name"]
self.log.info("Creating crawler: %s", crawler_name)
return self.glue_client.create_crawler(**crawler_kwargs)
@@ -124,26 +120,26 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5)
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check crawler status
:return: Crawler's status
"""
- failed_status = ['FAILED', 'CANCELLED']
+ failed_status = ["FAILED", "CANCELLED"]
while True:
crawler = self.get_crawler(crawler_name)
- crawler_state = crawler['State']
- if crawler_state == 'READY':
+ crawler_state = crawler["State"]
+ if crawler_state == "READY":
self.log.info("State: %s", crawler_state)
self.log.info("crawler_config: %s", crawler)
- crawler_status = crawler['LastCrawl']['Status']
+ crawler_status = crawler["LastCrawl"]["Status"]
if crawler_status in failed_status:
raise AirflowException(f"Status: {crawler_status}")
metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[
- 'CrawlerMetricsList'
+ "CrawlerMetricsList"
][0]
self.log.info("Status: %s", crawler_status)
- self.log.info("Last Runtime Duration (seconds): %s", metrics['LastRuntimeSeconds'])
- self.log.info("Median Runtime Duration (seconds): %s", metrics['MedianRuntimeSeconds'])
- self.log.info("Tables Created: %s", metrics['TablesCreated'])
- self.log.info("Tables Updated: %s", metrics['TablesUpdated'])
- self.log.info("Tables Deleted: %s", metrics['TablesDeleted'])
+ self.log.info("Last Runtime Duration (seconds): %s", metrics["LastRuntimeSeconds"])
+ self.log.info("Median Runtime Duration (seconds): %s", metrics["MedianRuntimeSeconds"])
+ self.log.info("Tables Created: %s", metrics["TablesCreated"])
+ self.log.info("Tables Updated: %s", metrics["TablesUpdated"])
+ self.log.info("Tables Deleted: %s", metrics["TablesDeleted"])
return crawler_status
@@ -152,9 +148,9 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5)
self.log.info("State: %s", crawler_state)
metrics = self.glue_client.get_crawler_metrics(CrawlerNameList=[crawler_name])[
- 'CrawlerMetricsList'
+ "CrawlerMetricsList"
][0]
- time_left = int(metrics['TimeLeftSeconds'])
+ time_left = int(metrics["TimeLeftSeconds"])
if time_left > 0:
self.log.info("Estimated Time Left (seconds): %s", time_left)
@@ -162,19 +158,3 @@ def wait_for_crawler_completion(self, crawler_name: str, poll_interval: int = 5)
self.log.info("Crawler should finish soon")
sleep(poll_interval)
-
-
-class AwsGlueCrawlerHook(GlueCrawlerHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.glue_crawler.GlueCrawlerHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.glue_crawler.GlueCrawlerHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/kinesis.py b/airflow/providers/amazon/aws/hooks/kinesis.py
index f15457c6c28d7..71fe675466005 100644
--- a/airflow/providers/amazon/aws/hooks/kinesis.py
+++ b/airflow/providers/amazon/aws/hooks/kinesis.py
@@ -15,9 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains AWS Firehose hook"""
-import warnings
+from __future__ import annotations
+
from typing import Iterable
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -44,19 +44,3 @@ def __init__(self, delivery_stream: str, *args, **kwargs) -> None:
def put_records(self, records: Iterable):
"""Write batch records to Kinesis Firehose"""
return self.get_conn().put_record_batch(DeliveryStreamName=self.delivery_stream, Records=records)
-
-
-class AwsFirehoseHook(FirehoseHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.kinesis.FirehoseHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.kinesis.FirehoseHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/lambda_function.py b/airflow/providers/amazon/aws/hooks/lambda_function.py
index b6819d9dba263..2919d37701aed 100644
--- a/airflow/providers/amazon/aws/hooks/lambda_function.py
+++ b/airflow/providers/amazon/aws/hooks/lambda_function.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains AWS Lambda hook"""
-import warnings
-from typing import Any, List, Optional
+from __future__ import annotations
+
+from typing import Any
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -51,11 +51,11 @@ def invoke_lambda(
self,
*,
function_name: str,
- invocation_type: Optional[str] = None,
- log_type: Optional[str] = None,
- client_context: Optional[str] = None,
- payload: Optional[str] = None,
- qualifier: Optional[str] = None,
+ invocation_type: str | None = None,
+ log_type: str | None = None,
+ client_context: str | None = None,
+ payload: str | None = None,
+ qualifier: str | None = None,
):
"""Invoke Lambda Function. Refer to the boto3 documentation for more info."""
invoke_args = {
@@ -76,22 +76,22 @@ def create_lambda(
role: str,
handler: str,
code: dict,
- description: Optional[str] = None,
- timeout: Optional[int] = None,
- memory_size: Optional[int] = None,
- publish: Optional[bool] = None,
- vpc_config: Optional[Any] = None,
- package_type: Optional[str] = None,
- dead_letter_config: Optional[Any] = None,
- environment: Optional[Any] = None,
- kms_key_arn: Optional[str] = None,
- tracing_config: Optional[Any] = None,
- tags: Optional[Any] = None,
- layers: Optional[list] = None,
- file_system_configs: Optional[List[Any]] = None,
- image_config: Optional[Any] = None,
- code_signing_config_arn: Optional[str] = None,
- architectures: Optional[List[str]] = None,
+ description: str | None = None,
+ timeout: int | None = None,
+ memory_size: int | None = None,
+ publish: bool | None = None,
+ vpc_config: Any | None = None,
+ package_type: str | None = None,
+ dead_letter_config: Any | None = None,
+ environment: Any | None = None,
+ kms_key_arn: str | None = None,
+ tracing_config: Any | None = None,
+ tags: Any | None = None,
+ layers: list | None = None,
+ file_system_configs: list[Any] | None = None,
+ image_config: Any | None = None,
+ code_signing_config_arn: str | None = None,
+ architectures: list[str] | None = None,
) -> dict:
"""Create a Lambda Function"""
create_function_args = {
@@ -120,19 +120,3 @@ def create_lambda(
return self.conn.create_function(
**{k: v for k, v in create_function_args.items() if v is not None},
)
-
-
-class AwsLambdaHook(LambdaHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.hooks.lambda_function.LambdaHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/logs.py b/airflow/providers/amazon/aws/hooks/logs.py
index 1c5e5e62fb709..e19ff275efc8c 100644
--- a/airflow/providers/amazon/aws/hooks/logs.py
+++ b/airflow/providers/amazon/aws/hooks/logs.py
@@ -15,12 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This module contains a hook (AwsLogsHook) with some very basic
functionality for interacting with AWS CloudWatch.
"""
-from typing import Dict, Generator, Optional
+from __future__ import annotations
+
+from typing import Generator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -59,7 +60,6 @@ def get_log_events(
This is for when there are multiple entries at the same timestamp.
:param start_from_head: whether to start from the beginning (True) of the log or
at the end of the log (False).
- :rtype: dict
:return: | A CloudWatch log event with the following key-value pairs:
| 'timestamp' (int): The time in milliseconds of the event.
| 'message' (str): The log event data.
@@ -68,7 +68,7 @@ def get_log_events(
next_token = None
while True:
if next_token is not None:
- token_arg: Optional[Dict[str, str]] = {'nextToken': next_token}
+ token_arg: dict[str, str] | None = {"nextToken": next_token}
else:
token_arg = {}
@@ -80,7 +80,7 @@ def get_log_events(
**token_arg,
)
- events = response['events']
+ events = response["events"]
event_count = len(events)
if event_count > skip:
@@ -92,7 +92,7 @@ def get_log_events(
yield from events
- if next_token != response['nextForwardToken']:
- next_token = response['nextForwardToken']
+ if next_token != response["nextForwardToken"]:
+ next_token = response["nextForwardToken"]
else:
return
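Review note on ``get_log_events``: the loop stops once ``nextForwardToken`` no longer changes between calls. A minimal sketch of that termination rule follows, assuming a hypothetical ``get_page`` callable and made-up page data in place of the real CloudWatch client; the ``skip`` handling of the real method is omitted.

```python
from __future__ import annotations

from typing import Callable, Generator


def iter_events(get_page: Callable[[str | None], dict]) -> Generator[dict, None, None]:
    """Yield events page by page until the forward token stops advancing."""
    next_token = None
    while True:
        response = get_page(next_token)
        yield from response["events"]
        if next_token != response["nextForwardToken"]:
            next_token = response["nextForwardToken"]
        else:
            # Same token twice in a row: the end of the stream was reached.
            return


# Made-up pages keyed by the token that requests them.
PAGES = {
    None: {"events": [{"message": "a"}], "nextForwardToken": "f1"},
    "f1": {"events": [{"message": "b"}], "nextForwardToken": "f2"},
    "f2": {"events": [], "nextForwardToken": "f2"},
}

messages = [event["message"] for event in iter_events(PAGES.__getitem__)]
```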
diff --git a/airflow/providers/amazon/aws/hooks/quicksight.py b/airflow/providers/amazon/aws/hooks/quicksight.py
index a7e90c36cf92a..3a03f3a3be2da 100644
--- a/airflow/providers/amazon/aws/hooks/quicksight.py
+++ b/airflow/providers/amazon/aws/hooks/quicksight.py
@@ -15,21 +15,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-import sys
import time
from botocore.exceptions import ClientError
from airflow import AirflowException
+from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
class QuickSightHook(AwsBaseHook):
"""
@@ -58,7 +54,7 @@ def create_ingestion(
ingestion_type: str,
wait_for_completion: bool = True,
check_interval: int = 30,
- ):
+ ) -> dict:
"""
Creates and starts a new SPICE ingestion for a dataset. Refreshes the SPICE datasets
@@ -70,9 +66,7 @@ def create_ingestion(
will check the status of QuickSight Ingestion
:return: Returns descriptive information about the created data ingestion
having Ingestion ARN, HTTP status, ingestion ID and ingestion status.
- :rtype: Dict
"""
-
self.log.info("Creating QuickSight Ingestion for data set id %s.", data_set_id)
quicksight_client = self.get_conn()
try:
@@ -97,7 +91,7 @@ def create_ingestion(
self.log.error("Failed to run Amazon QuickSight create_ingestion API, error: %s", general_error)
raise
- def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str):
+ def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str) -> str:
"""
Get the current status of QuickSight Create Ingestion API.
@@ -105,7 +99,6 @@ def get_status(self, aws_account_id: str, data_set_id: str, ingestion_id: str):
:param data_set_id: QuickSight Data Set ID
:param ingestion_id: QuickSight Ingestion ID
:return: An QuickSight Ingestion Status
- :rtype: str
"""
try:
describe_ingestion_response = self.get_conn().describe_ingestion(
@@ -136,7 +129,6 @@ def wait_for_state(
will check the status of QuickSight Ingestion
:return: response of describe_ingestion call after Ingestion is is done
"""
-
sec = 0
status = self.get_status(aws_account_id, data_set_id, ingestion_id)
while status in self.NON_TERMINAL_STATES and status != target_state:
diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py
index 3539e951cade6..15790e3add96b 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -16,16 +16,19 @@
# specific language governing permissions and limitations
# under the License.
"""Interact with AWS RDS."""
+from __future__ import annotations
-from typing import TYPE_CHECKING
+import time
+from typing import TYPE_CHECKING, Callable
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.exceptions import AirflowException, AirflowNotFoundException
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
if TYPE_CHECKING:
- from mypy_boto3_rds import RDSClient
+ from mypy_boto3_rds import RDSClient # noqa
-class RdsHook(AwsBaseHook):
+class RdsHook(AwsGenericHook["RDSClient"]):
"""
Interact with AWS RDS using proper client from the boto3 library.
@@ -39,7 +42,7 @@ class RdsHook(AwsBaseHook):
are passed down to the underlying AwsBaseHook.
.. seealso::
- :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
@@ -48,12 +51,299 @@ def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "rds"
super().__init__(*args, **kwargs)
- @property
- def conn(self) -> 'RDSClient':
+ def get_db_snapshot_state(self, snapshot_id: str) -> str:
"""
- Get the underlying boto3 RDS client (cached)
+ Get the current state of a DB instance snapshot.
- :return: boto3 RDS client
- :rtype: botocore.client.RDS
+ :param snapshot_id: The ID of the target DB instance snapshot
+ :return: Returns the status of the DB snapshot as a string (e.g. "available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB instance snapshot does not exist.
"""
- return super().conn
+ try:
+ response = self.conn.describe_db_snapshots(DBSnapshotIdentifier=snapshot_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBSnapshotNotFound":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBSnapshots"][0]["Status"].lower()
+
+ def wait_for_db_snapshot_state(
+ self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_snapshots` until the target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param snapshot_id: The ID of the target DB instance snapshot
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+ """
+
+ def poke():
+ return self.get_db_snapshot_state(snapshot_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted", "completed"):
+ waiter = self.conn.get_waiter(f"db_snapshot_{target_state}") # type: ignore
+ waiter.wait(
+ DBSnapshotIdentifier=snapshot_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("DB snapshot '%s' reached the '%s' state", snapshot_id, target_state)
+
+ def get_db_cluster_snapshot_state(self, snapshot_id: str) -> str:
+ """
+ Get the current state of a DB cluster snapshot.
+
+ :param snapshot_id: The ID of the target DB cluster snapshot.
+ :return: Returns the status of the DB cluster snapshot as a string (e.g. "available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB cluster snapshot does not exist.
+ """
+ try:
+ response = self.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=snapshot_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBClusterSnapshotNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBClusterSnapshots"][0]["Status"].lower()
+
+ def wait_for_db_cluster_snapshot_state(
+ self, snapshot_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_cluster_snapshots` until the target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param snapshot_id: The ID of the target DB cluster snapshot
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ A list of possible values for target_state:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_db_cluster_snapshots
+ """
+
+ def poke():
+ return self.get_db_cluster_snapshot_state(snapshot_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted"):
+ waiter = self.conn.get_waiter(f"db_cluster_snapshot_{target_state}") # type: ignore
+ waiter.wait(
+ DBClusterSnapshotIdentifier=snapshot_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("DB cluster snapshot '%s' reached the '%s' state", snapshot_id, target_state)
+
+ def get_export_task_state(self, export_task_id: str) -> str:
+ """
+ Gets the current state of an RDS snapshot export to Amazon S3.
+
+ :param export_task_id: The identifier of the target snapshot export task.
+ :return: Returns the status of the snapshot export task as a string (e.g. "canceled")
+ :rtype: str
+ :raises AirflowNotFoundException: If the export task does not exist.
+ """
+ try:
+ response = self.conn.describe_export_tasks(ExportTaskIdentifier=export_task_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "ExportTaskNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["ExportTasks"][0]["Status"].lower()
+
+ def wait_for_export_task_state(
+ self, export_task_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_export_tasks` until the target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param export_task_id: The identifier of the target snapshot export task.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ A list of possible values for target_state:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_export_tasks
+ """
+
+ def poke():
+ return self.get_export_task_state(export_task_id)
+
+ target_state = target_state.lower()
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("export task '%s' reached the '%s' state", export_task_id, target_state)
+
+ def get_event_subscription_state(self, subscription_name: str) -> str:
+ """
+ Gets the current state of an RDS event notification subscription.
+
+ :param subscription_name: The name of the target RDS event notification subscription.
+ :return: Returns the status of the event subscription as a string (e.g. "active")
+ :rtype: str
+ :raises AirflowNotFoundException: If the event subscription does not exist.
+ """
+ try:
+ response = self.conn.describe_event_subscriptions(SubscriptionName=subscription_name)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "SubscriptionNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["EventSubscriptionsList"][0]["Status"].lower()
+
+ def wait_for_event_subscription_state(
+ self, subscription_name: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_event_subscriptions` until the target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param subscription_name: The name of the target RDS event notification subscription.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ A list of possible values for target_state:
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.describe_event_subscriptions
+ """
+
+ def poke():
+ return self.get_event_subscription_state(subscription_name)
+
+ target_state = target_state.lower()
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("event subscription '%s' reached the '%s' state", subscription_name, target_state)
+
+ def get_db_instance_state(self, db_instance_id: str) -> str:
+ """
+ Get the current state of a DB instance.
+
+ :param db_instance_id: The ID of the target DB instance.
+ :return: Returns the status of the DB instance as a string (e.g. "available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB instance does not exist.
+ """
+ try:
+ response = self.conn.describe_db_instances(DBInstanceIdentifier=db_instance_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBInstanceNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBInstances"][0]["DBInstanceStatus"].lower()
+
+ def wait_for_db_instance_state(
+ self, db_instance_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_instances` until the target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param db_instance_id: The ID of the target DB instance.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ For information about DB instance statuses, see Viewing DB instance status in the Amazon RDS
+ User Guide.
+ https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
+ """
+
+ def poke():
+ return self.get_db_instance_state(db_instance_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted"):
+ waiter = self.conn.get_waiter(f"db_instance_{target_state}") # type: ignore
+ waiter.wait(
+ DBInstanceIdentifier=db_instance_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("DB instance '%s' reached the '%s' state", db_instance_id, target_state)
+
+ def get_db_cluster_state(self, db_cluster_id: str) -> str:
+ """
+ Get the current state of a DB cluster.
+
+ :param db_cluster_id: The ID of the target DB cluster.
+ :return: Returns the status of the DB cluster as a string (e.g. "available")
+ :rtype: str
+ :raises AirflowNotFoundException: If the DB cluster does not exist.
+ """
+ try:
+ response = self.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
+ except self.conn.exceptions.ClientError as e:
+ if e.response["Error"]["Code"] == "DBClusterNotFoundFault":
+ raise AirflowNotFoundException(e)
+ raise e
+ return response["DBClusters"][0]["Status"].lower()
+
+ def wait_for_db_cluster_state(
+ self, db_cluster_id: str, target_state: str, check_interval: int = 30, max_attempts: int = 40
+ ) -> None:
+ """
+ Polls :py:meth:`RDS.Client.describe_db_clusters` until the target state is reached.
+ An error is raised after a max number of attempts.
+
+ :param db_cluster_id: The ID of the target DB cluster.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+
+ .. seealso::
+ For information about DB instance statuses, see Viewing DB instance status in the Amazon RDS
+ User Guide.
+ https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/accessing-monitoring.html#Overview.DBInstance.Status
+ """
+
+ def poke():
+ return self.get_db_cluster_state(db_cluster_id)
+
+ target_state = target_state.lower()
+ if target_state in ("available", "deleted"):
+ waiter = self.conn.get_waiter(f"db_cluster_{target_state}") # type: ignore
+ waiter.wait(
+ DBClusterIdentifier=db_cluster_id,
+ WaiterConfig={"Delay": check_interval, "MaxAttempts": max_attempts},
+ )
+ else:
+ self._wait_for_state(poke, target_state, check_interval, max_attempts)
+ self.log.info("DB cluster '%s' reached the '%s' state", db_cluster_id, target_state)
+
+ def _wait_for_state(
+ self,
+ poke: Callable[..., str],
+ target_state: str,
+ check_interval: int,
+ max_attempts: int,
+ ) -> None:
+ """
+ Polls the poke function for the current state until it reaches the target_state.
+
+ :param poke: A function that returns the current state of the target resource as a string.
+ :param target_state: Wait until this state is reached
+ :param check_interval: The amount of time in seconds to wait between attempts
+ :param max_attempts: The maximum number of attempts to be made
+ """
+ state = poke()
+ tries = 1
+ while state != target_state:
+ self.log.info("Current state is %s", state)
+ if tries >= max_attempts:
+ raise AirflowException("Max attempts exceeded")
+ time.sleep(check_interval)
+ state = poke()
+ tries += 1
diff --git a/airflow/providers/amazon/aws/hooks/redshift.py b/airflow/providers/amazon/aws/hooks/redshift.py
deleted file mode 100644
index 1766564d273b5..0000000000000
--- a/airflow/providers/amazon/aws/hooks/redshift.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Interact with AWS Redshift clusters."""
-import warnings
-
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
-from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.hooks.redshift_cluster` "
- "or `airflow.providers.amazon.aws.hooks.redshift_sql` as appropriate.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-__all__ = ["RedshiftHook", "RedshiftSQLHook"]
diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index c7b3c51722141..d85929d062024 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import Any, Dict, List, Optional
+import warnings
+from typing import Any, Sequence
from botocore.exceptions import ClientError
@@ -35,6 +37,8 @@ class RedshiftHook(AwsBaseHook):
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
+ template_fields: Sequence[str] = ("cluster_identifier",)
+
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)
@@ -45,8 +49,8 @@ def create_cluster(
node_type: str,
master_username: str,
master_user_password: str,
- params: Dict[str, Any],
- ) -> Dict[str, Any]:
+ params: dict[str, Any],
+ ) -> dict[str, Any]:
"""
Creates a new cluster with the specified parameters
@@ -83,16 +87,16 @@ def cluster_status(self, cluster_identifier: str) -> str:
:param final_cluster_snapshot_identifier: Optional[str]
"""
try:
- response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)['Clusters']
- return response[0]['ClusterStatus'] if response else None
+ response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"]
+ return response[0]["ClusterStatus"] if response else None
except self.get_conn().exceptions.ClusterNotFoundFault:
- return 'cluster_not_found'
+ return "cluster_not_found"
def delete_cluster(
self,
cluster_identifier: str,
skip_final_cluster_snapshot: bool = True,
- final_cluster_snapshot_identifier: Optional[str] = None,
+ final_cluster_snapshot_identifier: str | None = None,
):
"""
Delete a cluster and optionally create a snapshot
@@ -101,27 +105,27 @@ def delete_cluster(
:param skip_final_cluster_snapshot: determines cluster snapshot creation
:param final_cluster_snapshot_identifier: name of final cluster snapshot
"""
- final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ''
+ final_cluster_snapshot_identifier = final_cluster_snapshot_identifier or ""
response = self.get_conn().delete_cluster(
ClusterIdentifier=cluster_identifier,
SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier,
)
- return response['Cluster'] if response['Cluster'] else None
+ return response["Cluster"] if response["Cluster"] else None
- def describe_cluster_snapshots(self, cluster_identifier: str) -> Optional[List[str]]:
+ def describe_cluster_snapshots(self, cluster_identifier: str) -> list[str] | None:
"""
Gets a list of snapshots for a cluster
:param cluster_identifier: unique identifier of a cluster
"""
response = self.get_conn().describe_cluster_snapshots(ClusterIdentifier=cluster_identifier)
- if 'Snapshots' not in response:
+ if "Snapshots" not in response:
return None
- snapshots = response['Snapshots']
+ snapshots = response["Snapshots"]
snapshots = [snapshot for snapshot in snapshots if snapshot["Status"]]
- snapshots.sort(key=lambda x: x['SnapshotCreateTime'], reverse=True)
+ snapshots.sort(key=lambda x: x["SnapshotCreateTime"], reverse=True)
return snapshots
def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identifier: str) -> str:
@@ -134,17 +138,48 @@ def restore_from_cluster_snapshot(self, cluster_identifier: str, snapshot_identi
response = self.get_conn().restore_from_cluster_snapshot(
ClusterIdentifier=cluster_identifier, SnapshotIdentifier=snapshot_identifier
)
- return response['Cluster'] if response['Cluster'] else None
+ return response["Cluster"] if response["Cluster"] else None
- def create_cluster_snapshot(self, snapshot_identifier: str, cluster_identifier: str) -> str:
+ def create_cluster_snapshot(
+ self, snapshot_identifier: str, cluster_identifier: str, retention_period: int = -1
+ ) -> str:
"""
Creates a snapshot of a cluster
:param snapshot_identifier: unique identifier for a snapshot of a cluster
:param cluster_identifier: unique identifier of a cluster
+ :param retention_period: The number of days that a manual snapshot is retained.
+ If the value is -1, the manual snapshot is retained indefinitely.
"""
response = self.get_conn().create_cluster_snapshot(
SnapshotIdentifier=snapshot_identifier,
ClusterIdentifier=cluster_identifier,
+ ManualSnapshotRetentionPeriod=retention_period,
)
- return response['Snapshot'] if response['Snapshot'] else None
+ return response["Snapshot"] if response["Snapshot"] else None
+
+ def get_cluster_snapshot_status(self, snapshot_identifier: str, cluster_identifier: str | None = None):
+ """
+ Return the Redshift cluster snapshot status, or ``None`` if the snapshot is not found.
+
+ :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
+ :param cluster_identifier: (deprecated) The unique identifier of the cluster
+ the snapshot was created from
+ """
+ if cluster_identifier:
+ warnings.warn(
+ "Parameter `cluster_identifier` is deprecated. "
+ "This option will be removed in a future version.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ try:
+ response = self.get_conn().describe_cluster_snapshots(
+ SnapshotIdentifier=snapshot_identifier,
+ )
+ snapshot = response.get("Snapshots")[0]
+ snapshot_status: str = snapshot.get("Status")
+ return snapshot_status
+ except self.get_conn().exceptions.ClusterSnapshotNotFoundFault:
+ return None
diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index 74459f58a5f8a..a2cb172979107 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -15,16 +15,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from typing import TYPE_CHECKING
-from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
if TYPE_CHECKING:
- from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient
+ from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient # noqa
-class RedshiftDataHook(AwsBaseHook):
+class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
"""
Interact with AWS Redshift Data, using the boto3 library
Hook attribute `conn` has all methods that listed in documentation
@@ -37,7 +38,7 @@ class RedshiftDataHook(AwsBaseHook):
are passed down to the underlying AwsBaseHook.
.. seealso::
- :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
+ :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook`
:param aws_conn_id: The Airflow connection used for AWS credentials.
"""
@@ -45,8 +46,3 @@ class RedshiftDataHook(AwsBaseHook):
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "redshift-data"
super().__init__(*args, **kwargs)
-
- @property
- def conn(self) -> 'RedshiftDataAPIServiceClient':
- """Get the underlying boto3 RedshiftDataAPIService client (cached)"""
- return super().conn
diff --git a/airflow/providers/amazon/aws/hooks/redshift_sql.py b/airflow/providers/amazon/aws/hooks/redshift_sql.py
index 03bb45f7ee128..120ce190ccb1b 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_sql.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_sql.py
@@ -14,21 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-import sys
-from typing import Dict, List, Optional, Union
+from __future__ import annotations
import redshift_connector
from redshift_connector import Connection as RedshiftConnection
from sqlalchemy import create_engine
from sqlalchemy.engine.url import URL
-from airflow.hooks.dbapi import DbApiHook
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from airflow.compat.functools import cached_property
+from airflow.providers.common.sql.hooks.sql import DbApiHook
class RedshiftSQLHook(DbApiHook):
@@ -44,40 +38,40 @@ class RedshiftSQLHook(DbApiHook):
get_sqlalchemy_engine() and get_uri() depend on sqlalchemy-amazon-redshift
"""
- conn_name_attr = 'redshift_conn_id'
- default_conn_name = 'redshift_default'
- conn_type = 'redshift'
- hook_name = 'Amazon Redshift'
+ conn_name_attr = "redshift_conn_id"
+ default_conn_name = "redshift_default"
+ conn_type = "redshift"
+ hook_name = "Amazon Redshift"
supports_autocommit = True
@staticmethod
- def get_ui_field_behavior() -> Dict:
+ def get_ui_field_behaviour() -> dict:
"""Returns custom field behavior"""
return {
"hidden_fields": [],
- "relabeling": {'login': 'User', 'schema': 'Database'},
+ "relabeling": {"login": "User", "schema": "Database"},
}
@cached_property
def conn(self):
return self.get_connection(self.redshift_conn_id) # type: ignore[attr-defined]
- def _get_conn_params(self) -> Dict[str, Union[str, int]]:
+ def _get_conn_params(self) -> dict[str, str | int]:
"""Helper method to retrieve connection args"""
conn = self.conn
- conn_params: Dict[str, Union[str, int]] = {}
+ conn_params: dict[str, str | int] = {}
if conn.login:
- conn_params['user'] = conn.login
+ conn_params["user"] = conn.login
if conn.password:
- conn_params['password'] = conn.password
+ conn_params["password"] = conn.password
if conn.host:
- conn_params['host'] = conn.host
+ conn_params["host"] = conn.host
if conn.port:
- conn_params['port'] = conn.port
+ conn_params["port"] = conn.port
if conn.schema:
- conn_params['database'] = conn.schema
+ conn_params["database"] = conn.schema
return conn_params
@@ -85,10 +79,13 @@ def get_uri(self) -> str:
"""Overrides DbApiHook get_uri to use redshift_connector sqlalchemy dialect as driver name"""
conn_params = self._get_conn_params()
- if 'user' in conn_params:
- conn_params['username'] = conn_params.pop('user')
+ if "user" in conn_params:
+ conn_params["username"] = conn_params.pop("user")
- return str(URL(drivername='redshift+redshift_connector', **conn_params))
+ # Compatibility: The 'create' factory method was added in SQLAlchemy 1.4
+ # to replace calling the default URL constructor directly.
+ create_url = getattr(URL, "create", URL)
+ return str(create_url(drivername="redshift+redshift_connector", **conn_params))
def get_sqlalchemy_engine(self, engine_kwargs=None):
"""Overrides DbApiHook get_sqlalchemy_engine to pass redshift_connector specific kwargs"""
@@ -103,13 +100,12 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
return create_engine(self.get_uri(), **engine_kwargs)
- def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]:
+ def get_table_primary_key(self, table: str, schema: str | None = "public") -> list[str] | None:
"""
Helper method that returns the table primary key
:param table: Name of the target table
- :param table: Name of the target schema, public by default
+ :param schema: Name of the target schema, public by default
:return: Primary key columns list
- :rtype: List[str]
"""
sql = """
select kcu.column_name
@@ -129,5 +125,5 @@ def get_conn(self) -> RedshiftConnection:
"""Returns a redshift_connector.Connection object"""
conn_params = self._get_conn_params()
conn_kwargs_dejson = self.conn.extra_dejson
- conn_kwargs: Dict = {**conn_params, **conn_kwargs_dejson}
+ conn_kwargs: dict = {**conn_params, **conn_kwargs_dejson}
return redshift_connector.connect(**conn_kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index e7e9f2de508da..e5b1bad5804f2 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -15,22 +15,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-
"""Interact with AWS S3, using the boto3 library."""
+from __future__ import annotations
+
import fnmatch
import gzip as gz
import io
import re
import shutil
+from copy import deepcopy
from datetime import datetime
from functools import wraps
from inspect import signature
from io import BytesIO
from pathlib import Path
-from tempfile import NamedTemporaryFile
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, cast
-from urllib.parse import urlparse
+from tempfile import NamedTemporaryFile, gettempdir
+from typing import Any, Callable, TypeVar, cast
+from urllib.parse import urlsplit
+from uuid import uuid4
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError
@@ -53,12 +55,12 @@ def provide_bucket_name(func: T) -> T:
def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
- if 'bucket_name' not in bound_args.arguments:
+ if "bucket_name" not in bound_args.arguments:
self = args[0]
if self.aws_conn_id:
connection = self.get_connection(self.aws_conn_id)
if connection.schema:
- bound_args.arguments['bucket_name'] = connection.schema
+ bound_args.arguments["bucket_name"] = connection.schema
return func(*bound_args.args, **bound_args.kwargs)
@@ -76,15 +78,15 @@ def unify_bucket_name_and_key(func: T) -> T:
def wrapper(*args, **kwargs) -> T:
bound_args = function_signature.bind(*args, **kwargs)
- if 'wildcard_key' in bound_args.arguments:
- key_name = 'wildcard_key'
- elif 'key' in bound_args.arguments:
- key_name = 'key'
+ if "wildcard_key" in bound_args.arguments:
+ key_name = "wildcard_key"
+ elif "key" in bound_args.arguments:
+ key_name = "key"
else:
- raise ValueError('Missing key parameter!')
+ raise ValueError("Missing key parameter!")
- if 'bucket_name' not in bound_args.arguments:
- bound_args.arguments['bucket_name'], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
+ if "bucket_name" not in bound_args.arguments:
+ bound_args.arguments["bucket_name"], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
bound_args.arguments[key_name]
)
@@ -97,6 +99,15 @@ class S3Hook(AwsBaseHook):
"""
Interact with AWS S3, using the boto3 library.
+ :param transfer_config_args: Configuration object for managed S3 transfers.
+ :param extra_args: Extra arguments that may be passed to the download/upload operations.
+
+ .. seealso::
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#s3-transfers
+
+ - For allowed upload extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS``.
+ - For allowed download extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_DOWNLOAD_ARGS``.
+
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -104,52 +115,67 @@ class S3Hook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- conn_type = 's3'
- hook_name = 'Amazon S3'
-
- def __init__(self, *args, **kwargs) -> None:
- kwargs['client_type'] = 's3'
+ def __init__(
+ self,
+ aws_conn_id: str | None = AwsBaseHook.default_conn_name,
+ transfer_config_args: dict | None = None,
+ extra_args: dict | None = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ kwargs["client_type"] = "s3"
+ kwargs["aws_conn_id"] = aws_conn_id
- self.extra_args = {}
- if 'extra_args' in kwargs:
- self.extra_args = kwargs['extra_args']
- if not isinstance(self.extra_args, dict):
- raise ValueError(f"extra_args '{self.extra_args!r}' must be of type {dict}")
- del kwargs['extra_args']
+ if transfer_config_args and not isinstance(transfer_config_args, dict):
+ raise TypeError(f"transfer_config_args expected dict, got {type(transfer_config_args).__name__}.")
+ self.transfer_config = TransferConfig(**transfer_config_args or {})
- self.transfer_config = TransferConfig()
- if 'transfer_config_args' in kwargs:
- transport_config_args = kwargs['transfer_config_args']
- if not isinstance(transport_config_args, dict):
- raise ValueError(f"transfer_config_args '{transport_config_args!r} must be of type {dict}")
- self.transfer_config = TransferConfig(**transport_config_args)
- del kwargs['transfer_config_args']
+ if extra_args and not isinstance(extra_args, dict):
+ raise TypeError(f"extra_args expected dict, got {type(extra_args).__name__}.")
+ self._extra_args = extra_args or {}
super().__init__(*args, **kwargs)
+ @property
+ def extra_args(self):
+ """Return hook's extra arguments (immutable)."""
+ return deepcopy(self._extra_args)
+
@staticmethod
- def parse_s3_url(s3url: str) -> Tuple[str, str]:
+ def parse_s3_url(s3url: str) -> tuple[str, str]:
"""
Parses the S3 Url into a bucket name and key.
+ See https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html
+ for valid url formats
:param s3url: The S3 Url to parse.
:return: the parsed bucket name and key
- :rtype: tuple of str
"""
- parsed_url = urlparse(s3url)
-
- if not parsed_url.netloc:
- raise AirflowException(f'Please provide a bucket_name instead of "{s3url}"')
-
- bucket_name = parsed_url.netloc
- key = parsed_url.path.lstrip('/')
-
+ format = s3url.split("//")
+ if format[0].lower() == "s3:":
+ parsed_url = urlsplit(s3url)
+ if not parsed_url.netloc:
+ raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
+
+ bucket_name = parsed_url.netloc
+ key = parsed_url.path.lstrip("/")
+ elif format[0] == "https:":
+ temp_split = format[1].split(".")
+ if temp_split[0] == "s3":
+ split_url = format[1].split("/")
+ bucket_name = split_url[1]
+ key = "/".join(split_url[2:])
+ elif temp_split[1] == "s3":
+ bucket_name = temp_split[0]
+ key = "/".join(format[1].split("/")[1:])
+ else:
+ raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
return bucket_name, key
@staticmethod
def get_s3_bucket_key(
- bucket: Optional[str], key: str, bucket_param_name: str, key_param_name: str
- ) -> Tuple[str, str]:
+ bucket: str | None, key: str, bucket_param_name: str, key_param_name: str
+ ) -> tuple[str, str]:
"""
Get the S3 bucket name and key from either:
- bucket name and key. Return the info as it is after checking `key` is a relative path
@@ -160,59 +186,64 @@ def get_s3_bucket_key(
:param bucket_param_name: The parameter name containing the bucket name
:param key_param_name: The parameter name containing the key name
:return: the parsed bucket name and key
- :rtype: tuple of str
"""
-
if bucket is None:
return S3Hook.parse_s3_url(key)
- parsed_url = urlparse(key)
- if parsed_url.scheme != '' or parsed_url.netloc != '':
+ parsed_url = urlsplit(key)
+ if parsed_url.scheme != "" or parsed_url.netloc != "":
raise TypeError(
- f'If `{bucket_param_name}` is provided, {key_param_name} should be a relative path '
- 'from root level, rather than a full s3:// url'
+ f"If `{bucket_param_name}` is provided, {key_param_name} should be a relative path "
+ "from root level, rather than a full s3:// url"
)
return bucket, key
@provide_bucket_name
- def check_for_bucket(self, bucket_name: Optional[str] = None) -> bool:
+ def check_for_bucket(self, bucket_name: str | None = None) -> bool:
"""
Check if bucket_name exists.
:param bucket_name: the name of the bucket
:return: True if it exists and False if not.
- :rtype: bool
"""
try:
self.get_conn().head_bucket(Bucket=bucket_name)
return True
except ClientError as e:
- self.log.error(e.response["Error"]["Message"])
+ # The head_bucket api is odd in that it cannot return proper
+ # exception objects, so error codes must be used. Only 200, 404 and 403
+ # are ever returned. See the following links for more details:
+ # https://github.com/boto/boto3/issues/2499
+ # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.head_bucket
+ return_code = int(e.response["Error"]["Code"])
+ if return_code == 404:
+ self.log.info('Bucket "%s" does not exist', bucket_name)
+ elif return_code == 403:
+ self.log.error(
+ 'Access to bucket "%s" is forbidden or there was an error with the request', bucket_name
+ )
+ self.log.error(e)
return False
@provide_bucket_name
- def get_bucket(self, bucket_name: Optional[str] = None) -> object:
+ def get_bucket(self, bucket_name: str | None = None) -> object:
"""
Returns a boto3.S3.Bucket object
:param bucket_name: the name of the bucket
:return: the bucket object to the bucket name.
- :rtype: boto3.S3.Bucket
"""
- # Buckets have no regions, and we cannot remove the region name from _get_credentials as we would
- # break compatibility, so we set it explicitly to None.
- session, endpoint_url = self._get_credentials(region_name=None)
- s3_resource = session.resource(
+ s3_resource = self.get_session().resource(
"s3",
- endpoint_url=endpoint_url,
+ endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
return s3_resource.Bucket(bucket_name)
@provide_bucket_name
- def create_bucket(self, bucket_name: Optional[str] = None, region_name: Optional[str] = None) -> None:
+ def create_bucket(self, bucket_name: str | None = None, region_name: str | None = None) -> None:
"""
Creates an Amazon S3 bucket.
@@ -220,16 +251,22 @@ def create_bucket(self, bucket_name: Optional[str] = None, region_name: Optional
:param region_name: The name of the aws region in which to create the bucket.
"""
if not region_name:
- region_name = self.get_conn().meta.region_name
- if region_name == 'us-east-1':
+ if self.conn_region_name == "aws-global":
+ raise AirflowException(
+ "Unable to create bucket if `region_name` not set "
+ "and boto3 configured to use s3 regional endpoints."
+ )
+ region_name = self.conn_region_name
+
+ if region_name == "us-east-1":
self.get_conn().create_bucket(Bucket=bucket_name)
else:
self.get_conn().create_bucket(
- Bucket=bucket_name, CreateBucketConfiguration={'LocationConstraint': region_name}
+ Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region_name}
)
@provide_bucket_name
- def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[str] = None) -> bool:
+ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: str | None = None) -> bool:
"""
Checks that a prefix exists in a bucket
@@ -237,10 +274,9 @@ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[st
:param prefix: a key prefix
:param delimiter: the delimiter marks key hierarchy.
:return: False if the prefix does not exist in the bucket and True if it does.
- :rtype: bool
"""
prefix = prefix + delimiter if prefix[-1] != delimiter else prefix
- prefix_split = re.split(fr'(\w+[{delimiter}])$', prefix, 1)
+ prefix_split = re.split(rf"(\w+[{delimiter}])$", prefix, 1)
previous_level = prefix_split[0]
plist = self.list_prefixes(bucket_name, previous_level, delimiter)
return prefix in plist
@@ -248,11 +284,11 @@ def check_for_prefix(self, prefix: str, delimiter: str, bucket_name: Optional[st
@provide_bucket_name
def list_prefixes(
self,
- bucket_name: Optional[str] = None,
- prefix: Optional[str] = None,
- delimiter: Optional[str] = None,
- page_size: Optional[int] = None,
- max_items: Optional[int] = None,
+ bucket_name: str | None = None,
+ prefix: str | None = None,
+ delimiter: str | None = None,
+ page_size: int | None = None,
+ max_items: int | None = None,
) -> list:
"""
Lists prefixes in a bucket under prefix
@@ -263,29 +299,28 @@ def list_prefixes(
:param page_size: pagination size
:param max_items: maximum items to return
:return: a list of matched prefixes
- :rtype: list
"""
- prefix = prefix or ''
- delimiter = delimiter or ''
+ prefix = prefix or ""
+ delimiter = delimiter or ""
config = {
- 'PageSize': page_size,
- 'MaxItems': max_items,
+ "PageSize": page_size,
+ "MaxItems": max_items,
}
- paginator = self.get_conn().get_paginator('list_objects_v2')
+ paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name, Prefix=prefix, Delimiter=delimiter, PaginationConfig=config
)
- prefixes = [] # type: List[str]
+ prefixes: list[str] = []
for page in response:
- if 'CommonPrefixes' in page:
- prefixes.extend(common_prefix['Prefix'] for common_prefix in page['CommonPrefixes'])
+ if "CommonPrefixes" in page:
+ prefixes.extend(common_prefix["Prefix"] for common_prefix in page["CommonPrefixes"])
return prefixes
def _list_key_object_filter(
- self, keys: list, from_datetime: Optional[datetime] = None, to_datetime: Optional[datetime] = None
+ self, keys: list, from_datetime: datetime | None = None, to_datetime: datetime | None = None
) -> list:
def _is_in_period(input_date: datetime) -> bool:
if from_datetime is not None and input_date <= from_datetime:
@@ -294,20 +329,20 @@ def _is_in_period(input_date: datetime) -> bool:
return False
return True
- return [k['Key'] for k in keys if _is_in_period(k['LastModified'])]
+ return [k["Key"] for k in keys if _is_in_period(k["LastModified"])]
@provide_bucket_name
def list_keys(
self,
- bucket_name: Optional[str] = None,
- prefix: Optional[str] = None,
- delimiter: Optional[str] = None,
- page_size: Optional[int] = None,
- max_items: Optional[int] = None,
- start_after_key: Optional[str] = None,
- from_datetime: Optional[datetime] = None,
- to_datetime: Optional[datetime] = None,
- object_filter: Optional[Callable[..., list]] = None,
+ bucket_name: str | None = None,
+ prefix: str | None = None,
+ delimiter: str | None = None,
+ page_size: int | None = None,
+ max_items: int | None = None,
+ start_after_key: str | None = None,
+ from_datetime: datetime | None = None,
+ to_datetime: datetime | None = None,
+ object_filter: Callable[..., list] | None = None,
) -> list:
"""
Lists keys in a bucket under prefix and not containing delimiter
@@ -331,8 +366,8 @@ def list_keys(
def object_filter(
keys: list,
- from_datetime: Optional[datetime] = None,
- to_datetime: Optional[datetime] = None,
+ from_datetime: datetime | None = None,
+ to_datetime: datetime | None = None,
) -> list:
def _is_in_period(input_date: datetime) -> bool:
if from_datetime is not None and input_date < from_datetime:
@@ -345,18 +380,17 @@ def _is_in_period(input_date: datetime) -> bool:
return [k["Key"] for k in keys if _is_in_period(k["LastModified"])]
:return: a list of matched keys
- :rtype: list
"""
- prefix = prefix or ''
- delimiter = delimiter or ''
- start_after_key = start_after_key or ''
+ prefix = prefix or ""
+ delimiter = delimiter or ""
+ start_after_key = start_after_key or ""
self.object_filter_usr = object_filter
config = {
- 'PageSize': page_size,
- 'MaxItems': max_items,
+ "PageSize": page_size,
+ "MaxItems": max_items,
}
- paginator = self.get_conn().get_paginator('list_objects_v2')
+ paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(
Bucket=bucket_name,
Prefix=prefix,
@@ -365,10 +399,10 @@ def _is_in_period(input_date: datetime) -> bool:
StartAfter=start_after_key,
)
- keys = [] # type: List[str]
+ keys: list[str] = []
for page in response:
- if 'Contents' in page:
- keys.extend(iter(page['Contents']))
+ if "Contents" in page:
+ keys.extend(iter(page["Contents"]))
if self.object_filter_usr is not None:
return self.object_filter_usr(keys, from_datetime, to_datetime)
@@ -378,10 +412,10 @@ def _is_in_period(input_date: datetime) -> bool:
def get_file_metadata(
self,
prefix: str,
- bucket_name: Optional[str] = None,
- page_size: Optional[int] = None,
- max_items: Optional[int] = None,
- ) -> List:
+ bucket_name: str | None = None,
+ page_size: int | None = None,
+ max_items: int | None = None,
+ ) -> list:
"""
Lists metadata objects in a bucket under prefix
@@ -390,32 +424,30 @@ def get_file_metadata(
:param page_size: pagination size
:param max_items: maximum items to return
:return: a list of metadata of objects
- :rtype: list
"""
config = {
- 'PageSize': page_size,
- 'MaxItems': max_items,
+ "PageSize": page_size,
+ "MaxItems": max_items,
}
- paginator = self.get_conn().get_paginator('list_objects_v2')
+ paginator = self.get_conn().get_paginator("list_objects_v2")
response = paginator.paginate(Bucket=bucket_name, Prefix=prefix, PaginationConfig=config)
files = []
for page in response:
- if 'Contents' in page:
- files += page['Contents']
+ if "Contents" in page:
+ files += page["Contents"]
return files
@provide_bucket_name
@unify_bucket_name_and_key
- def head_object(self, key: str, bucket_name: Optional[str] = None) -> Optional[dict]:
+ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
"""
Retrieves metadata of an object
:param key: S3 key that will point to the file
:param bucket_name: Name of the bucket in which the file is stored
:return: metadata of an object
- :rtype: dict
"""
try:
return self.get_conn().head_object(Bucket=bucket_name, Key=key)
@@ -427,35 +459,30 @@ def head_object(self, key: str, bucket_name: Optional[str] = None) -> Optional[d
@provide_bucket_name
@unify_bucket_name_and_key
- def check_for_key(self, key: str, bucket_name: Optional[str] = None) -> bool:
+ def check_for_key(self, key: str, bucket_name: str | None = None) -> bool:
"""
Checks if a key exists in a bucket
:param key: S3 key that will point to the file
:param bucket_name: Name of the bucket in which the file is stored
:return: True if the key exists and False if not.
- :rtype: bool
"""
obj = self.head_object(key, bucket_name)
return obj is not None
@provide_bucket_name
@unify_bucket_name_and_key
- def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer:
+ def get_key(self, key: str, bucket_name: str | None = None) -> S3Transfer:
"""
Returns a boto3.s3.Object
:param key: the path to the key
:param bucket_name: the name of the bucket
:return: the key object from the bucket
- :rtype: boto3.s3.Object
"""
- # Buckets have no regions, and we cannot remove the region name from _get_credentials as we would
- # break compatibility, so we set it explicitly to None.
- session, endpoint_url = self._get_credentials(region_name=None)
- s3_resource = session.resource(
+ s3_resource = self.get_session().resource(
"s3",
- endpoint_url=endpoint_url,
+ endpoint_url=self.conn_config.endpoint_url,
config=self.config,
verify=self.verify,
)
@@ -465,28 +492,27 @@ def get_key(self, key: str, bucket_name: Optional[str] = None) -> S3Transfer:
@provide_bucket_name
@unify_bucket_name_and_key
- def read_key(self, key: str, bucket_name: Optional[str] = None) -> str:
+ def read_key(self, key: str, bucket_name: str | None = None) -> str:
"""
Reads a key from S3
:param key: S3 key that will point to the file
:param bucket_name: Name of the bucket in which the file is stored
:return: the content of the key
- :rtype: str
"""
obj = self.get_key(key, bucket_name)
- return obj.get()['Body'].read().decode('utf-8')
+ return obj.get()["Body"].read().decode("utf-8")
@provide_bucket_name
@unify_bucket_name_and_key
def select_key(
self,
key: str,
- bucket_name: Optional[str] = None,
- expression: Optional[str] = None,
- expression_type: Optional[str] = None,
- input_serialization: Optional[Dict[str, Any]] = None,
- output_serialization: Optional[Dict[str, Any]] = None,
+ bucket_name: str | None = None,
+ expression: str | None = None,
+ expression_type: str | None = None,
+ input_serialization: dict[str, Any] | None = None,
+ output_serialization: dict[str, Any] | None = None,
) -> str:
"""
Reads a key with S3 Select.
@@ -498,19 +524,18 @@ def select_key(
:param input_serialization: S3 Select input data serialization format
:param output_serialization: S3 Select output data serialization format
:return: retrieved subset of original data by S3 Select
- :rtype: str
.. seealso::
For more details about S3 Select parameters:
- http://boto3.readthedocs.io/en/latest/reference/services/s3.html#S3.Client.select_object_content
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.select_object_content
"""
- expression = expression or 'SELECT * FROM S3Object'
- expression_type = expression_type or 'SQL'
+ expression = expression or "SELECT * FROM S3Object"
+ expression_type = expression_type or "SQL"
if input_serialization is None:
- input_serialization = {'CSV': {}}
+ input_serialization = {"CSV": {}}
if output_serialization is None:
- output_serialization = {'CSV': {}}
+ output_serialization = {"CSV": {}}
response = self.get_conn().select_object_content(
Bucket=bucket_name,
@@ -521,14 +546,14 @@ def select_key(
OutputSerialization=output_serialization,
)
- return b''.join(
- event['Records']['Payload'] for event in response['Payload'] if 'Records' in event
- ).decode('utf-8')
+ return b"".join(
+ event["Records"]["Payload"] for event in response["Payload"] if "Records" in event
+ ).decode("utf-8")
@provide_bucket_name
@unify_bucket_name_and_key
def check_for_wildcard_key(
- self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = ''
+ self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = ""
) -> bool:
"""
Checks that a key matching a wildcard expression exists in a bucket
@@ -537,7 +562,6 @@ def check_for_wildcard_key(
:param bucket_name: the name of the bucket
:param delimiter: the delimiter marks key hierarchy
:return: True if a key exists and False if not.
- :rtype: bool
"""
return (
self.get_wildcard_key(wildcard_key=wildcard_key, bucket_name=bucket_name, delimiter=delimiter)
@@ -547,7 +571,7 @@ def check_for_wildcard_key(
@provide_bucket_name
@unify_bucket_name_and_key
def get_wildcard_key(
- self, wildcard_key: str, bucket_name: Optional[str] = None, delimiter: str = ''
+ self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = ""
) -> S3Transfer:
"""
Returns a boto3.s3.Object object matching the wildcard expression
@@ -556,9 +580,8 @@ def get_wildcard_key(
:param bucket_name: the name of the bucket
:param delimiter: the delimiter marks key hierarchy
:return: the key object from the bucket or None if none has been found.
- :rtype: boto3.s3.Object
"""
- prefix = re.split(r'[\[\*\?]', wildcard_key, 1)[0]
+ prefix = re.split(r"[\[\*\?]", wildcard_key, 1)[0]
key_list = self.list_keys(bucket_name, prefix=prefix, delimiter=delimiter)
key_matches = [k for k in key_list if fnmatch.fnmatch(k, wildcard_key)]
if key_matches:
@@ -569,13 +592,13 @@ def get_wildcard_key(
@unify_bucket_name_and_key
def load_file(
self,
- filename: Union[Path, str],
+ filename: Path | str,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
) -> None:
"""
Loads a local file to S3
@@ -598,15 +621,15 @@ def load_file(
extra_args = self.extra_args
if encrypt:
- extra_args['ServerSideEncryption'] = "AES256"
+ extra_args["ServerSideEncryption"] = "AES256"
if gzip:
- with open(filename, 'rb') as f_in:
- filename_gz = f'{f_in.name}.gz'
- with gz.open(filename_gz, 'wb') as f_out:
+ with open(filename, "rb") as f_in:
+ filename_gz = f"{f_in.name}.gz"
+ with gz.open(filename_gz, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
filename = filename_gz
if acl_policy:
- extra_args['ACL'] = acl_policy
+ extra_args["ACL"] = acl_policy
client = self.get_conn()
client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
@@ -617,12 +640,12 @@ def load_string(
self,
string_data: str,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
- encoding: Optional[str] = None,
- acl_policy: Optional[str] = None,
- compression: Optional[str] = None,
+ encoding: str | None = None,
+ acl_policy: str | None = None,
+ compression: str | None = None,
) -> None:
"""
Loads a string to S3
@@ -642,18 +665,18 @@ def load_string(
object to be uploaded
:param compression: Type of compression to use, currently only gzip is supported.
"""
- encoding = encoding or 'utf-8'
+ encoding = encoding or "utf-8"
bytes_data = string_data.encode(encoding)
# Compress string
- available_compressions = ['gzip']
+ available_compressions = ["gzip"]
if compression is not None and compression not in available_compressions:
raise NotImplementedError(
f"Received {compression} compression type. "
f"String can currently be compressed in {available_compressions} only."
)
- if compression == 'gzip':
+ if compression == "gzip":
bytes_data = gz.compress(bytes_data)
file_obj = io.BytesIO(bytes_data)
@@ -667,10 +690,10 @@ def load_bytes(
self,
bytes_data: bytes,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
) -> None:
"""
Loads bytes to S3
@@ -698,10 +721,10 @@ def load_file_obj(
self,
file_obj: BytesIO,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
) -> None:
"""
Loads a file object to S3
@@ -722,19 +745,19 @@ def _upload_file_obj(
self,
file_obj: BytesIO,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
) -> None:
if not replace and self.check_for_key(key, bucket_name):
raise ValueError(f"The key {key} already exists.")
extra_args = self.extra_args
if encrypt:
- extra_args['ServerSideEncryption'] = "AES256"
+ extra_args["ServerSideEncryption"] = "AES256"
if acl_policy:
- extra_args['ACL'] = acl_policy
+ extra_args["ACL"] = acl_policy
client = self.get_conn()
client.upload_fileobj(
@@ -749,10 +772,10 @@ def copy_object(
self,
source_bucket_key: str,
dest_bucket_key: str,
- source_bucket_name: Optional[str] = None,
- dest_bucket_name: Optional[str] = None,
- source_version_id: Optional[str] = None,
- acl_policy: Optional[str] = None,
+ source_bucket_name: str | None = None,
+ dest_bucket_name: str | None = None,
+ source_version_id: str | None = None,
+ acl_policy: str | None = None,
) -> None:
"""
Creates a copy of an object that is already stored in S3.
@@ -779,17 +802,17 @@ def copy_object(
:param acl_policy: The string to specify the canned ACL policy for the
object to be copied which is private by default.
"""
- acl_policy = acl_policy or 'private'
+ acl_policy = acl_policy or "private"
dest_bucket_name, dest_bucket_key = self.get_s3_bucket_key(
- dest_bucket_name, dest_bucket_key, 'dest_bucket_name', 'dest_bucket_key'
+ dest_bucket_name, dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
)
source_bucket_name, source_bucket_key = self.get_s3_bucket_key(
- source_bucket_name, source_bucket_key, 'source_bucket_name', 'source_bucket_key'
+ source_bucket_name, source_bucket_key, "source_bucket_name", "source_bucket_key"
)
- copy_source = {'Bucket': source_bucket_name, 'Key': source_bucket_key, 'VersionId': source_version_id}
+ copy_source = {"Bucket": source_bucket_name, "Key": source_bucket_key, "VersionId": source_version_id}
response = self.get_conn().copy_object(
Bucket=dest_bucket_name, Key=dest_bucket_key, CopySource=copy_source, ACL=acl_policy
)
@@ -803,7 +826,6 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None:
:param bucket_name: Bucket name
:param force_delete: Enable this to delete bucket even if not empty
:return: None
- :rtype: None
"""
if force_delete:
bucket_keys = self.list_keys(bucket_name=bucket_name)
@@ -811,7 +833,7 @@ def delete_bucket(self, bucket_name: str, force_delete: bool = False) -> None:
self.delete_objects(bucket=bucket_name, keys=bucket_keys)
self.conn.delete_bucket(Bucket=bucket_name)
- def delete_objects(self, bucket: str, keys: Union[str, list]) -> None:
+ def delete_objects(self, bucket: str, keys: str | list) -> None:
"""
Delete keys from the bucket.
@@ -834,16 +856,21 @@ def delete_objects(self, bucket: str, keys: Union[str, list]) -> None:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/s3.html#S3.Client.delete_objects
for chunk in chunks(keys, chunk_size=1000):
response = s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]})
- deleted_keys = [x['Key'] for x in response.get("Deleted", [])]
+ deleted_keys = [x["Key"] for x in response.get("Deleted", [])]
self.log.info("Deleted: %s", deleted_keys)
if "Errors" in response:
- errors_keys = [x['Key'] for x in response.get("Errors", [])]
+ errors_keys = [x["Key"] for x in response.get("Errors", [])]
raise AirflowException(f"Errors when deleting: {errors_keys}")
@provide_bucket_name
@unify_bucket_name_and_key
def download_file(
- self, key: str, bucket_name: Optional[str] = None, local_path: Optional[str] = None
+ self,
+ key: str,
+ bucket_name: str | None = None,
+ local_path: str | None = None,
+ preserve_file_name: bool = False,
+ use_autogenerated_subdir: bool = True,
) -> str:
"""
Downloads a file from the S3 location to the local file system.
@@ -852,33 +879,66 @@ def download_file(
:param bucket_name: The specific bucket to use.
:param local_path: The local path to the downloaded file. If no path is provided it will use the
system's temporary directory.
+ :param preserve_file_name: If True, the downloaded file keeps the name it has in S3.
+ If False, a random file name is generated.
+ Default: False.
+ :param use_autogenerated_subdir: Used together with ``preserve_file_name=True``. If True, the file
+ is downloaded into a randomly named subdirectory of ``local_path``, which avoids collisions
+ between tasks that download the same file name. Set it to False if you want a
+ predictable path.
+ Default: True.
:return: the file name.
- :rtype: str
"""
- self.log.info('Downloading source S3 file from Bucket %s with path %s', bucket_name, key)
+ self.log.info(
+ "This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
+ "want to use the original method from S3 API, please call "
+ "'S3Hook.get_conn().download_file()'"
+ )
+
+ self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)
try:
s3_obj = self.get_key(key, bucket_name)
except ClientError as e:
- if e.response.get('Error', {}).get('Code') == 404:
+ if e.response.get("Error", {}).get("Code") == 404:
raise AirflowException(
- f'The source file in Bucket {bucket_name} with path {key} does not exist'
+ f"The source file in Bucket {bucket_name} with path {key} does not exist"
)
else:
raise e
- with NamedTemporaryFile(dir=local_path, prefix='airflow_tmp_', delete=False) as local_tmp_file:
- s3_obj.download_fileobj(local_tmp_file)
+ if preserve_file_name:
+ local_dir = local_path if local_path else gettempdir()
+ subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
+ filename_in_s3 = s3_obj.key.rsplit("/", 1)[-1]
+ file_path = Path(local_dir, subdir, filename_in_s3)
+
+ if file_path.is_file():
+ self.log.error("file '%s' already exists. Failing the task and not overwriting it", file_path)
+ raise FileExistsError
+
+ file_path.parent.mkdir(exist_ok=True, parents=True)
+
+ file = open(file_path, "wb")
+ else:
+ file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) # type: ignore
+
+ with file:
+ s3_obj.download_fileobj(
+ file,
+ ExtraArgs=self.extra_args,
+ Config=self.transfer_config,
+ )
- return local_tmp_file.name
+ return file.name
def generate_presigned_url(
self,
client_method: str,
- params: Optional[dict] = None,
+ params: dict | None = None,
expires_in: int = 3600,
- http_method: Optional[str] = None,
- ) -> Optional[str]:
+ http_method: str | None = None,
+ ) -> str | None:
"""
Generate a presigned url given a client, its method, and arguments
@@ -889,7 +949,6 @@ def generate_presigned_url(
:param http_method: The http method to use on the generated url.
By default, the http method is whatever is used in the method's model.
:return: The presigned url.
- :rtype: str
"""
s3_client = self.get_conn()
try:
@@ -902,17 +961,16 @@ def generate_presigned_url(
return None
@provide_bucket_name
- def get_bucket_tagging(self, bucket_name: Optional[str] = None) -> Optional[List[Dict[str, str]]]:
+ def get_bucket_tagging(self, bucket_name: str | None = None) -> list[dict[str, str]] | None:
"""
Gets a List of tags from a bucket.
:param bucket_name: The name of the bucket.
:return: A List containing the key/value pairs for the tags
- :rtype: Optional[List[Dict[str, str]]]
"""
try:
s3_client = self.get_conn()
- result = s3_client.get_bucket_tagging(Bucket=bucket_name)['TagSet']
+ result = s3_client.get_bucket_tagging(Bucket=bucket_name)["TagSet"]
self.log.info("S3 Bucket Tag Info: %s", result)
return result
except ClientError as e:
@@ -922,10 +980,10 @@ def get_bucket_tagging(self, bucket_name: Optional[str] = None) -> Optional[List
@provide_bucket_name
def put_bucket_tagging(
self,
- tag_set: Optional[List[Dict[str, str]]] = None,
- key: Optional[str] = None,
- value: Optional[str] = None,
- bucket_name: Optional[str] = None,
+ tag_set: list[dict[str, str]] | None = None,
+ key: str | None = None,
+ value: str | None = None,
+ bucket_name: str | None = None,
) -> None:
"""
Overwrites the existing TagSet with provided tags. Must provide either a TagSet or a key/value pair.
@@ -935,33 +993,31 @@ def put_bucket_tagging(
:param value: The Value for the new TagSet entry.
:param bucket_name: The name of the bucket.
:return: None
- :rtype: None
"""
self.log.info("S3 Bucket Tag Info:\tKey: %s\tValue: %s\tSet: %s", key, value, tag_set)
if not tag_set:
tag_set = []
if key and value:
- tag_set.append({'Key': key, 'Value': value})
+ tag_set.append({"Key": key, "Value": value})
elif not tag_set or (key or value):
- message = 'put_bucket_tagging() requires either a predefined TagSet or a key/value pair.'
+ message = "put_bucket_tagging() requires either a predefined TagSet or a key/value pair."
self.log.error(message)
raise ValueError(message)
try:
s3_client = self.get_conn()
- s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={'TagSet': tag_set})
+ s3_client.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": tag_set})
except ClientError as e:
self.log.error(e)
raise e
@provide_bucket_name
- def delete_bucket_tagging(self, bucket_name: Optional[str] = None) -> None:
+ def delete_bucket_tagging(self, bucket_name: str | None = None) -> None:
"""
Deletes all tags from a bucket.
:param bucket_name: The name of the bucket.
:return: None
- :rtype: None
"""
s3_client = self.get_conn()
s3_client.delete_bucket_tagging(Bucket=bucket_name)
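Reviewer note on the ``download_file`` change in the S3 hook diff above: the ``preserve_file_name=True`` branch builds the local path from three parts (target directory, optional random subdirectory, last segment of the S3 key). The following is a minimal standalone sketch of that path logic only, assuming nothing beyond the standard library. ``resolve_download_path`` is a hypothetical helper name used for illustration; it is not part of ``S3Hook``, and the key and directory values are made up.

```python
from __future__ import annotations

from pathlib import Path
from tempfile import gettempdir
from uuid import uuid4


def resolve_download_path(
    s3_key: str,
    local_path: str | None = None,
    use_autogenerated_subdir: bool = True,
) -> Path:
    """Hypothetical helper mirroring the path logic of the preserve_file_name=True branch."""
    # Fall back to the system temporary directory when no local path is given.
    local_dir = local_path if local_path else gettempdir()
    # An 8-character random suffix keeps concurrent tasks from colliding on the same file name.
    subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
    # Only the last segment of the S3 key is kept as the local file name.
    filename_in_s3 = s3_key.rsplit("/", 1)[-1]
    # pathlib drops the empty segment, so subdir="" yields local_dir/filename.
    return Path(local_dir, subdir, filename_in_s3)


# Illustrative values only: a predictable path versus a collision-safe one.
predictable = resolve_download_path("data/2022/report.csv", "/tmp/downloads", use_autogenerated_subdir=False)
randomized = resolve_download_path("data/2022/report.csv", "/tmp/downloads")
```

With ``use_autogenerated_subdir=False`` the result is fully determined by ``local_path`` and the key, which is why the diff checks ``file_path.is_file()`` and raises ``FileExistsError`` rather than overwriting an existing file.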
diff --git a/airflow/providers/amazon/aws/hooks/sagemaker.py b/airflow/providers/amazon/aws/hooks/sagemaker.py
index 2c8c28a738ec3..ae57e832d1b90 100644
--- a/airflow/providers/amazon/aws/hooks/sagemaker.py
+++ b/airflow/providers/amazon/aws/hooks/sagemaker.py
@@ -15,15 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import collections
import os
+import re
import tarfile
import tempfile
import time
import warnings
from datetime import datetime
from functools import partial
-from typing import Any, Callable, Dict, Generator, List, Optional, Set, cast
+from typing import Any, Callable, Generator, cast
from botocore.exceptions import ClientError
@@ -49,10 +52,10 @@ class LogState:
# Position is a tuple that includes the last read timestamp and the number of items that were read
# at that time. This is used to figure out which event to start with on the next read.
-Position = collections.namedtuple('Position', ['timestamp', 'skip'])
+Position = collections.namedtuple("Position", ["timestamp", "skip"])
-def argmin(arr, f: Callable) -> Optional[int]:
+def argmin(arr, f: Callable) -> int | None:
"""Return the index, i, in arr that minimizes f(arr[i])"""
min_value = None
min_idx = None
@@ -73,28 +76,28 @@ def secondary_training_status_changed(current_job_description: dict, prev_job_de
:return: Whether the secondary status message of a training job changed or not.
"""
- current_secondary_status_transitions = current_job_description.get('SecondaryStatusTransitions')
+ current_secondary_status_transitions = current_job_description.get("SecondaryStatusTransitions")
if current_secondary_status_transitions is None or len(current_secondary_status_transitions) == 0:
return False
prev_job_secondary_status_transitions = (
- prev_job_description.get('SecondaryStatusTransitions') if prev_job_description is not None else None
+ prev_job_description.get("SecondaryStatusTransitions") if prev_job_description is not None else None
)
last_message = (
- prev_job_secondary_status_transitions[-1]['StatusMessage']
+ prev_job_secondary_status_transitions[-1]["StatusMessage"]
if prev_job_secondary_status_transitions is not None
and len(prev_job_secondary_status_transitions) > 0
- else ''
+ else ""
)
- message = current_job_description['SecondaryStatusTransitions'][-1]['StatusMessage']
+ message = current_job_description["SecondaryStatusTransitions"][-1]["StatusMessage"]
return message != last_message
def secondary_training_status_message(
- job_description: Dict[str, List[Any]], prev_description: Optional[dict]
+ job_description: dict[str, list[Any]], prev_description: dict | None
) -> str:
"""
Returns a string containing the start time and the secondary training job status message.
@@ -104,14 +107,14 @@ def secondary_training_status_message(
:return: Job status string to be printed.
"""
- current_transitions = job_description.get('SecondaryStatusTransitions')
+ current_transitions = job_description.get("SecondaryStatusTransitions")
if current_transitions is None or len(current_transitions) == 0:
- return ''
+ return ""
prev_transitions_num = 0
if prev_description is not None:
- if prev_description.get('SecondaryStatusTransitions') is not None:
- prev_transitions_num = len(prev_description['SecondaryStatusTransitions'])
+ if prev_description.get("SecondaryStatusTransitions") is not None:
+ prev_transitions_num = len(prev_description["SecondaryStatusTransitions"])
transitions_to_print = (
current_transitions[-1:]
@@ -121,13 +124,13 @@ def secondary_training_status_message(
status_strs = []
for transition in transitions_to_print:
- message = transition['StatusMessage']
- time_str = timezone.convert_to_utc(cast(datetime, job_description['LastModifiedTime'])).strftime(
- '%Y-%m-%d %H:%M:%S'
+ message = transition["StatusMessage"]
+ time_str = timezone.convert_to_utc(cast(datetime, job_description["LastModifiedTime"])).strftime(
+ "%Y-%m-%d %H:%M:%S"
)
status_strs.append(f"{time_str} {transition['Status']} - {message}")
- return '\n'.join(status_strs)
+ return "\n".join(status_strs)
class SageMakerHook(AwsBaseHook):
@@ -141,12 +144,12 @@ class SageMakerHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- non_terminal_states = {'InProgress', 'Stopping'}
- endpoint_non_terminal_states = {'Creating', 'Updating', 'SystemUpdating', 'RollingBack', 'Deleting'}
- failed_states = {'Failed'}
+ non_terminal_states = {"InProgress", "Stopping"}
+ endpoint_non_terminal_states = {"Creating", "Updating", "SystemUpdating", "RollingBack", "Deleting"}
+ failed_states = {"Failed"}
def __init__(self, *args, **kwargs):
- super().__init__(client_type='sagemaker', *args, **kwargs)
+ super().__init__(client_type="sagemaker", *args, **kwargs)
self.s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
self.logs_hook = AwsLogsHook(aws_conn_id=self.aws_conn_id)
@@ -164,7 +167,7 @@ def tar_and_s3_upload(self, path: str, key: str, bucket: str) -> None:
files = [os.path.join(path, name) for name in os.listdir(path)]
else:
files = [path]
- with tarfile.open(mode='w:gz', fileobj=temp_file) as tar_file:
+ with tarfile.open(mode="w:gz", fileobj=temp_file) as tar_file:
for f in files:
tar_file.add(f, arcname=os.path.basename(f))
temp_file.seek(0)
@@ -175,27 +178,25 @@ def configure_s3_resources(self, config: dict) -> None:
Extract the S3 operations from the configuration and execute them.
:param config: config of SageMaker operation
- :rtype: dict
"""
- s3_operations = config.pop('S3Operations', None)
+ s3_operations = config.pop("S3Operations", None)
if s3_operations is not None:
- create_bucket_ops = s3_operations.get('S3CreateBucket', [])
- upload_ops = s3_operations.get('S3Upload', [])
+ create_bucket_ops = s3_operations.get("S3CreateBucket", [])
+ upload_ops = s3_operations.get("S3Upload", [])
for op in create_bucket_ops:
- self.s3_hook.create_bucket(bucket_name=op['Bucket'])
+ self.s3_hook.create_bucket(bucket_name=op["Bucket"])
for op in upload_ops:
- if op['Tar']:
- self.tar_and_s3_upload(op['Path'], op['Key'], op['Bucket'])
+ if op["Tar"]:
+ self.tar_and_s3_upload(op["Path"], op["Key"], op["Bucket"])
else:
- self.s3_hook.load_file(op['Path'], op['Key'], op['Bucket'])
+ self.s3_hook.load_file(op["Path"], op["Key"], op["Bucket"])
def check_s3_url(self, s3url: str) -> bool:
"""
Check if an S3 URL exists
:param s3url: S3 url
- :rtype: bool
"""
bucket, key = S3Hook.parse_s3_url(s3url)
if not self.s3_hook.check_for_bucket(bucket_name=bucket):
@@ -203,7 +204,7 @@ def check_s3_url(self, s3url: str) -> bool:
if (
key
and not self.s3_hook.check_for_key(key=key, bucket_name=bucket)
- and not self.s3_hook.check_for_prefix(prefix=key, bucket_name=bucket, delimiter='/')
+ and not self.s3_hook.check_for_prefix(prefix=key, bucket_name=bucket, delimiter="/")
):
# check if s3 key exists in the case user provides a single file
# or if s3 prefix exists in the case user provides multiple files in
@@ -221,9 +222,9 @@ def check_training_config(self, training_config: dict) -> None:
:return: None
"""
if "InputDataConfig" in training_config:
- for channel in training_config['InputDataConfig']:
- if "S3DataSource" in channel['DataSource']:
- self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
+ for channel in training_config["InputDataConfig"]:
+ if "S3DataSource" in channel["DataSource"]:
+ self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
def check_tuning_config(self, tuning_config: dict) -> None:
"""
@@ -232,39 +233,9 @@ def check_tuning_config(self, tuning_config: dict) -> None:
:param tuning_config: tuning_config
:return: None
"""
- for channel in tuning_config['TrainingJobDefinition']['InputDataConfig']:
- if "S3DataSource" in channel['DataSource']:
- self.check_s3_url(channel['DataSource']['S3DataSource']['S3Uri'])
-
- def get_log_conn(self):
- """
- This method is deprecated.
- Please use :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.
- """
- warnings.warn(
- "Method `get_log_conn` has been deprecated. "
- "Please use `airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_conn` instead.",
- category=DeprecationWarning,
- stacklevel=2,
- )
-
- return self.logs_hook.get_conn()
-
- def log_stream(self, log_group, stream_name, start_time=0, skip=0):
- """
- This method is deprecated.
- Please use
- :py:meth:`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.
- """
- warnings.warn(
- "Method `log_stream` has been deprecated. "
- "Please use "
- "`airflow.providers.amazon.aws.hooks.logs.AwsLogsHook.get_log_events` instead.",
- category=DeprecationWarning,
- stacklevel=2,
- )
-
- return self.logs_hook.get_log_events(log_group, stream_name, start_time, skip)
+ for channel in tuning_config["TrainingJobDefinition"]["InputDataConfig"]:
+ if "S3DataSource" in channel["DataSource"]:
+ self.check_s3_url(channel["DataSource"]["S3DataSource"]["S3Uri"])
def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Generator:
"""
@@ -283,7 +254,7 @@ def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Ge
self.logs_hook.get_log_events(log_group, s, positions[s].timestamp, positions[s].skip)
for s in streams
]
- events: List[Optional[Any]] = []
+ events: list[Any | None] = []
for event_stream in event_iters:
if not event_stream:
events.append(None)
@@ -294,7 +265,7 @@ def multi_stream_iter(self, log_group: str, streams: list, positions=None) -> Ge
events.append(None)
while any(events):
- i = argmin(events, lambda x: x['timestamp'] if x else 9999999999) or 0
+ i = argmin(events, lambda x: x["timestamp"] if x else 9999999999) or 0
yield i, events[i]
try:
events[i] = next(event_iters[i])
@@ -307,7 +278,7 @@ def create_training_job(
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
Starts a model training job. After training completes, Amazon SageMaker saves
@@ -327,7 +298,7 @@ def create_training_job(
response = self.get_conn().create_training_job(**config)
if print_log:
self.check_training_status_with_log(
- config['TrainingJobName'],
+ config["TrainingJobName"],
self.non_terminal_states,
self.failed_states,
wait_for_completion,
@@ -336,17 +307,17 @@ def create_training_job(
)
elif wait_for_completion:
describe_response = self.check_status(
- config['TrainingJobName'],
- 'TrainingJobStatus',
+ config["TrainingJobName"],
+ "TrainingJobStatus",
self.describe_training_job,
check_interval,
max_ingestion_time,
)
billable_time = (
- describe_response['TrainingEndTime'] - describe_response['TrainingStartTime']
- ) * describe_response['ResourceConfig']['InstanceCount']
- self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1)
+ describe_response["TrainingEndTime"] - describe_response["TrainingStartTime"]
+ ) * describe_response["ResourceConfig"]["InstanceCount"]
+ self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
return response
@@ -355,7 +326,7 @@ def create_tuning_job(
config: dict,
wait_for_completion: bool = True,
check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
Starts a hyperparameter tuning job. A hyperparameter tuning job finds the
@@ -378,8 +349,8 @@ def create_tuning_job(
response = self.get_conn().create_hyper_parameter_tuning_job(**config)
if wait_for_completion:
self.check_status(
- config['HyperParameterTuningJobName'],
- 'HyperParameterTuningJobStatus',
+ config["HyperParameterTuningJobName"],
+ "HyperParameterTuningJobStatus",
self.describe_tuning_job,
check_interval,
max_ingestion_time,
@@ -391,7 +362,7 @@ def create_transform_job(
config: dict,
wait_for_completion: bool = True,
check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
Starts a transform job. A transform job uses a trained model to get inferences
@@ -406,14 +377,14 @@ def create_transform_job(
None implies no timeout for any SageMaker job.
:return: A response to transform job creation
"""
- if "S3DataSource" in config['TransformInput']['DataSource']:
- self.check_s3_url(config['TransformInput']['DataSource']['S3DataSource']['S3Uri'])
+ if "S3DataSource" in config["TransformInput"]["DataSource"]:
+ self.check_s3_url(config["TransformInput"]["DataSource"]["S3DataSource"]["S3Uri"])
response = self.get_conn().create_transform_job(**config)
if wait_for_completion:
self.check_status(
- config['TransformJobName'],
- 'TransformJobStatus',
+ config["TransformJobName"],
+ "TransformJobStatus",
self.describe_transform_job,
check_interval,
max_ingestion_time,
@@ -425,7 +396,7 @@ def create_processing_job(
config: dict,
wait_for_completion: bool = True,
check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
Use Amazon SageMaker Processing to analyze data and evaluate machine learning
@@ -445,8 +416,8 @@ def create_processing_job(
response = self.get_conn().create_processing_job(**config)
if wait_for_completion:
self.check_status(
- config['ProcessingJobName'],
- 'ProcessingJobStatus',
+ config["ProcessingJobName"],
+ "ProcessingJobStatus",
self.describe_processing_job,
check_interval,
max_ingestion_time,
@@ -486,7 +457,7 @@ def create_endpoint(
config: dict,
wait_for_completion: bool = True,
check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
When you create a serverless endpoint, SageMaker provisions and manages
@@ -511,8 +482,8 @@ def create_endpoint(
response = self.get_conn().create_endpoint(**config)
if wait_for_completion:
self.check_status(
- config['EndpointName'],
- 'EndpointStatus',
+ config["EndpointName"],
+ "EndpointStatus",
self.describe_endpoint,
check_interval,
max_ingestion_time,
@@ -525,7 +496,7 @@ def update_endpoint(
config: dict,
wait_for_completion: bool = True,
check_interval: int = 30,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
Deploys the new EndpointConfig specified in the request, switches to using
@@ -544,8 +515,8 @@ def update_endpoint(
response = self.get_conn().update_endpoint(**config)
if wait_for_completion:
self.check_status(
- config['EndpointName'],
- 'EndpointStatus',
+ config["EndpointName"],
+ "EndpointStatus",
self.describe_endpoint,
check_interval,
max_ingestion_time,
@@ -573,7 +544,7 @@ def describe_training_job_with_log(
last_describe_job_call: float,
):
"""Return the training job info associated with job_name and print CloudWatch logs"""
- log_group = '/aws/sagemaker/TrainingJobs'
+ log_group = "/aws/sagemaker/TrainingJobs"
if len(stream_names) < instance_count:
# Log streams are created whenever a container starts writing to stdout/err, so this list
@@ -582,11 +553,11 @@ def describe_training_job_with_log(
try:
streams = logs_conn.describe_log_streams(
logGroupName=log_group,
- logStreamNamePrefix=job_name + '/',
- orderBy='LogStreamName',
+ logStreamNamePrefix=job_name + "/",
+ orderBy="LogStreamName",
limit=instance_count,
)
- stream_names = [s['logStreamName'] for s in streams['logStreams']]
+ stream_names = [s["logStreamName"] for s in streams["logStreams"]]
positions.update(
[(s, Position(timestamp=0, skip=0)) for s in stream_names if s not in positions]
)
@@ -597,12 +568,12 @@ def describe_training_job_with_log(
if len(stream_names) > 0:
for idx, event in self.multi_stream_iter(log_group, stream_names, positions):
- self.log.info(event['message'])
+ self.log.info(event["message"])
ts, count = positions[stream_names[idx]]
- if event['timestamp'] == ts:
+ if event["timestamp"] == ts:
positions[stream_names[idx]] = Position(timestamp=ts, skip=count + 1)
else:
- positions[stream_names[idx]] = Position(timestamp=event['timestamp'], skip=1)
+ positions[stream_names[idx]] = Position(timestamp=event["timestamp"], skip=1)
if state == LogState.COMPLETE:
return state, last_description, last_describe_job_call
@@ -617,7 +588,7 @@ def describe_training_job_with_log(
self.log.info(secondary_training_status_message(description, last_description))
last_description = description
- status = description['TrainingJobStatus']
+ status = description["TrainingJobStatus"]
if status not in self.non_terminal_states:
state = LogState.JOB_COMPLETE
@@ -681,8 +652,8 @@ def check_status(
key: str,
describe_function: Callable,
check_interval: int,
- max_ingestion_time: Optional[int] = None,
- non_terminal_states: Optional[Set] = None,
+ max_ingestion_time: int | None = None,
+ non_terminal_states: set | None = None,
):
"""
Check status of a SageMaker job
@@ -704,34 +675,30 @@ def check_status(
non_terminal_states = self.non_terminal_states
sec = 0
- running = True
- while running:
+ while True:
time.sleep(check_interval)
sec += check_interval
try:
response = describe_function(job_name)
status = response[key]
- self.log.info('Job still running for %s seconds... current status is %s', sec, status)
+ self.log.info("Job still running for %s seconds... current status is %s", sec, status)
except KeyError:
- raise AirflowException('Could not get status of the SageMaker job')
+ raise AirflowException("Could not get status of the SageMaker job")
except ClientError:
- raise AirflowException('AWS request failed, check logs for more info')
+ raise AirflowException("AWS request failed, check logs for more info")
- if status in non_terminal_states:
- running = True
- elif status in self.failed_states:
+ if status in self.failed_states:
raise AirflowException(f"SageMaker job failed because {response['FailureReason']}")
- else:
- running = False
+ elif status not in non_terminal_states:
+ break
if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
- raise AirflowException(f'SageMaker job took more than {max_ingestion_time} seconds')
+ raise AirflowException(f"SageMaker job took more than {max_ingestion_time} seconds")
- self.log.info('SageMaker Job completed')
- response = describe_function(job_name)
+ self.log.info("SageMaker Job completed")
return response
def check_training_status_with_log(
@@ -741,7 +708,7 @@ def check_training_status_with_log(
failed_states: set,
wait_for_completion: bool,
check_interval: int,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
):
"""
Display the logs for a given training job, optionally tailing them until the
@@ -761,8 +728,8 @@ def check_training_status_with_log(
sec = 0
description = self.describe_training_job(job_name)
self.log.info(secondary_training_status_message(description, None))
- instance_count = description['ResourceConfig']['InstanceCount']
- status = description['TrainingJobStatus']
+ instance_count = description["ResourceConfig"]["InstanceCount"]
+ status = description["TrainingJobStatus"]
stream_names: list = [] # The list of log streams
positions: dict = {} # The current position in each stream, map of stream name -> position
@@ -812,21 +779,21 @@ def check_training_status_with_log(
if max_ingestion_time and sec > max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
- raise AirflowException(f'SageMaker job took more than {max_ingestion_time} seconds')
+ raise AirflowException(f"SageMaker job took more than {max_ingestion_time} seconds")
if wait_for_completion:
- status = last_description['TrainingJobStatus']
+ status = last_description["TrainingJobStatus"]
if status in failed_states:
- reason = last_description.get('FailureReason', '(No reason provided)')
- raise AirflowException(f'Error training {job_name}: {status} Reason: {reason}')
+ reason = last_description.get("FailureReason", "(No reason provided)")
+ raise AirflowException(f"Error training {job_name}: {status} Reason: {reason}")
billable_time = (
- last_description['TrainingEndTime'] - last_description['TrainingStartTime']
+ last_description["TrainingEndTime"] - last_description["TrainingStartTime"]
) * instance_count
- self.log.info('Billable seconds: %d', int(billable_time.total_seconds()) + 1)
+ self.log.info("Billable seconds: %d", int(billable_time.total_seconds()) + 1)
def list_training_jobs(
- self, name_contains: Optional[str] = None, max_results: Optional[int] = None, **kwargs
- ) -> List[Dict]:
+ self, name_contains: str | None = None, max_results: int | None = None, **kwargs
+ ) -> list[dict]:
"""
This method wraps boto3's `list_training_jobs`. The training job name and max results are configurable
via arguments. Other arguments are not, and should be provided via kwargs. Note boto3 expects these in
@@ -844,28 +811,42 @@ def list_training_jobs(
:param kwargs: (optional) kwargs to boto3's list_training_jobs method
:return: results of the list_training_jobs request
"""
- config = {}
+ config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs)
+ list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
+ results = self._list_request(
+ list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
+ )
+ return results
- if name_contains:
- if "NameContains" in kwargs:
- raise AirflowException("Either name_contains or NameContains can be provided, not both.")
- config["NameContains"] = name_contains
+ def list_transform_jobs(
+ self, name_contains: str | None = None, max_results: int | None = None, **kwargs
+ ) -> list[dict]:
+ """
+ This method wraps boto3's `list_transform_jobs`.
+ The transform job name and max results are configurable via arguments.
+ Other arguments are not, and should be provided via kwargs. Note boto3 expects these in
+ CamelCase format, for example:
- if "MaxResults" in kwargs and kwargs["MaxResults"] is not None:
- if max_results:
- raise AirflowException("Either max_results or MaxResults can be provided, not both.")
- # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
- max_results = kwargs["MaxResults"]
- del kwargs["MaxResults"]
+ .. code-block:: python
- config.update(kwargs)
- list_training_jobs_request = partial(self.get_conn().list_training_jobs, **config)
+ list_transform_jobs(name_contains="myjob", StatusEquals="Failed")
+
+ .. seealso::
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_transform_jobs
+
+ :param name_contains: (optional) partial name to match
+ :param max_results: (optional) maximum number of results to return. None returns all results
+ :param kwargs: (optional) kwargs to boto3's list_transform_jobs method
+ :return: results of the list_transform_jobs request
+ """
+ config, max_results = self._preprocess_list_request_args(name_contains, max_results, **kwargs)
+ list_transform_jobs_request = partial(self.get_conn().list_transform_jobs, **config)
results = self._list_request(
- list_training_jobs_request, "TrainingJobSummaries", max_results=max_results
+ list_transform_jobs_request, "TransformJobSummaries", max_results=max_results
)
return results
- def list_processing_jobs(self, **kwargs) -> List[Dict]:
+ def list_processing_jobs(self, **kwargs) -> list[dict]:
"""
This method wraps boto3's `list_processing_jobs`. All arguments should be provided via kwargs.
Note boto3 expects these in CamelCase format, for example:
@@ -886,9 +867,40 @@ def list_processing_jobs(self, **kwargs) -> List[Dict]:
)
return results
+ def _preprocess_list_request_args(
+ self, name_contains: str | None = None, max_results: int | None = None, **kwargs
+ ) -> tuple[dict[str, Any], int | None]:
+ """
+ This method preprocesses the arguments to the boto3's list_* methods.
+ It converts the name_contains and max_results arguments to the boto3-compliant CamelCase format.
+ This method also makes sure that these two arguments are only set once.
+
+ :param name_contains: (optional) partial name to match
+ :param max_results: (optional) maximum number of results to return
+ :param kwargs: (optional) kwargs to boto3's list_* method
+ :return: Tuple with config dict to be passed to boto3's list_* method and max_results parameter
+ """
+ config = {}
+
+ if name_contains:
+ if "NameContains" in kwargs:
+ raise AirflowException("Either name_contains or NameContains can be provided, not both.")
+ config["NameContains"] = name_contains
+
+ if "MaxResults" in kwargs and kwargs["MaxResults"] is not None:
+ if max_results:
+ raise AirflowException("Either max_results or MaxResults can be provided, not both.")
+ # Unset MaxResults, we'll use the SageMakerHook's internal method for iteratively fetching results
+ max_results = kwargs["MaxResults"]
+ del kwargs["MaxResults"]
+
+ config.update(kwargs)
+
+ return config, max_results
+
def _list_request(
- self, partial_func: Callable, result_key: str, max_results: Optional[int] = None
- ) -> List[Dict]:
+ self, partial_func: Callable, result_key: str, max_results: int | None = None
+ ) -> list[dict]:
"""
All AWS boto3 list_* requests return results in batches (if the key "NextToken" is contained in the
result, there are more results to fetch). The default AWS batch size is 10, and configurable up to
@@ -905,7 +917,7 @@ def _list_request(
"""
sagemaker_max_results = 100 # Fixed number set by AWS
- results: List[Dict] = []
+ results: list[dict] = []
next_token = None
while True:
@@ -929,13 +941,63 @@ def _list_request(
next_token = response["NextToken"]
def find_processing_job_by_name(self, processing_job_name: str) -> bool:
- """Query processing job by name"""
+ """
+ Query processing job by name
+
+ This method is deprecated.
+ Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.
+ """
+ warnings.warn(
+ "This method is deprecated. "
+ "Please use `airflow.providers.amazon.aws.hooks.sagemaker.count_processing_jobs_by_name`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ return bool(self.count_processing_jobs_by_name(processing_job_name))
+
+ @staticmethod
+ def _name_matches_pattern(
+ processing_job_name: str,
+ found_name: str,
+ job_name_suffix: str | None = None,
+ ) -> bool:
+ pattern = re.compile(f"^{processing_job_name}({job_name_suffix})?$")
+ return pattern.fullmatch(found_name) is not None
+
+ def count_processing_jobs_by_name(
+ self,
+ processing_job_name: str,
+ job_name_suffix: str | None = None,
+ throttle_retry_delay: int = 2,
+ retries: int = 3,
+ ) -> int:
+ """
+ Returns the number of processing jobs found with the provided name prefix.
+ :param processing_job_name: The prefix to look for.
+ :param job_name_suffix: The optional suffix which may be appended to deduplicate an existing job name.
+ :param throttle_retry_delay: Seconds to wait if a ThrottlingException is hit.
+ :param retries: The max number of times to retry.
+ :returns: The number of processing jobs that start with the provided prefix.
+ """
try:
- self.get_conn().describe_processing_job(ProcessingJobName=processing_job_name)
- return True
+ jobs = self.get_conn().list_processing_jobs(NameContains=processing_job_name)
+ # We want to make sure the job name starts with the provided name, not just contains it.
+ matching_jobs = [
+ job["ProcessingJobName"]
+ for job in jobs["ProcessingJobSummaries"]
+ if self._name_matches_pattern(processing_job_name, job["ProcessingJobName"], job_name_suffix)
+ ]
+ return len(matching_jobs)
except ClientError as e:
- if e.response['Error']['Code'] in ['ValidationException', 'ResourceNotFound']:
- return False
+ if e.response["Error"]["Code"] == "ResourceNotFound":
+ # No jobs found with that name. This is good, return 0.
+ return 0
+ if e.response["Error"]["Code"] == "ThrottlingException" and retries:
+ # If we hit a ThrottlingException, back off a little and try again.
+ time.sleep(throttle_retry_delay)
+ return self.count_processing_jobs_by_name(
+ processing_job_name, job_name_suffix, throttle_retry_delay * 2, retries - 1
+ )
raise
def delete_model(self, model_name: str):
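The throttling branch of ``count_processing_jobs_by_name`` above retries with a doubling delay. The sketch below shows the same backoff idea in isolation, as a loop rather than recursion. ``ThrottledError`` and ``call_with_backoff`` are hypothetical names used only for illustration; they are not part of the hook or of boto3, and the ``sleep`` argument is an assumption added so the example can run without waiting.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


class ThrottledError(Exception):
    """Hypothetical stand-in for a ClientError whose code is ThrottlingException."""


def call_with_backoff(
    func: Callable[[], T],
    delay: float = 2,
    retries: int = 3,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call func, doubling the wait after each throttled attempt."""
    while True:
        try:
            return func()
        except ThrottledError:
            if not retries:
                # Retries exhausted: surface the error to the caller.
                raise
            sleep(delay)
            delay *= 2
            retries -= 1
```

With the defaults, a call that is throttled on every attempt sleeps three times, doubling the delay each time, before the error propagates.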
diff --git a/airflow/providers/amazon/aws/hooks/secrets_manager.py b/airflow/providers/amazon/aws/hooks/secrets_manager.py
index f1596d4961f4e..f20c833473236 100644
--- a/airflow/providers/amazon/aws/hooks/secrets_manager.py
+++ b/airflow/providers/amazon/aws/hooks/secrets_manager.py
@@ -15,11 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
import base64
import json
-from typing import Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -36,24 +35,23 @@ class SecretsManagerHook(AwsBaseHook):
"""
def __init__(self, *args, **kwargs):
- super().__init__(client_type='secretsmanager', *args, **kwargs)
+ super().__init__(client_type="secretsmanager", *args, **kwargs)
- def get_secret(self, secret_name: str) -> Union[str, bytes]:
+ def get_secret(self, secret_name: str) -> str | bytes:
"""
Retrieve secret value from AWS Secrets Manager as a str or bytes
reflecting format it stored in the AWS Secrets Manager
:param secret_name: name of the secrets.
:return: Union[str, bytes] with the information about the secrets
- :rtype: Union[str, bytes]
"""
# Depending on whether the secret is a string or binary, one of
# these fields will be populated.
get_secret_value_response = self.get_conn().get_secret_value(SecretId=secret_name)
- if 'SecretString' in get_secret_value_response:
- secret = get_secret_value_response['SecretString']
+ if "SecretString" in get_secret_value_response:
+ secret = get_secret_value_response["SecretString"]
else:
- secret = base64.b64decode(get_secret_value_response['SecretBinary'])
+ secret = base64.b64decode(get_secret_value_response["SecretBinary"])
return secret
def get_secret_as_dict(self, secret_name: str) -> dict:
@@ -62,6 +60,5 @@ def get_secret_as_dict(self, secret_name: str) -> dict:
:param secret_name: name of the secrets.
:return: dict with the information about the secrets
- :rtype: dict
"""
return json.loads(self.get_secret(secret_name))
diff --git a/airflow/providers/amazon/aws/hooks/ses.py b/airflow/providers/amazon/aws/hooks/ses.py
index 92dcce7ecb617..b58024cfa439b 100644
--- a/airflow/providers/amazon/aws/hooks/ses.py
+++ b/airflow/providers/amazon/aws/hooks/ses.py
@@ -15,8 +15,9 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains AWS SES Hook"""
-import warnings
-from typing import Any, Dict, Iterable, List, Optional, Union
+from __future__ import annotations
+
+from typing import Any, Iterable
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.email import build_mime_message
@@ -34,23 +35,23 @@ class SesHook(AwsBaseHook):
"""
def __init__(self, *args, **kwargs) -> None:
- kwargs['client_type'] = 'ses'
+ kwargs["client_type"] = "ses"
super().__init__(*args, **kwargs)
def send_email(
self,
mail_from: str,
- to: Union[str, Iterable[str]],
+ to: str | Iterable[str],
subject: str,
html_content: str,
- files: Optional[List[str]] = None,
- cc: Optional[Union[str, Iterable[str]]] = None,
- bcc: Optional[Union[str, Iterable[str]]] = None,
- mime_subtype: str = 'mixed',
- mime_charset: str = 'utf-8',
- reply_to: Optional[str] = None,
- return_path: Optional[str] = None,
- custom_headers: Optional[Dict[str, Any]] = None,
+ files: list[str] | None = None,
+ cc: str | Iterable[str] | None = None,
+ bcc: str | Iterable[str] | None = None,
+ mime_subtype: str = "mixed",
+ mime_charset: str = "utf-8",
+ reply_to: str | None = None,
+ return_path: str | None = None,
+ custom_headers: dict[str, Any] | None = None,
) -> dict:
"""
Send email using Amazon Simple Email Service
@@ -76,9 +77,9 @@ def send_email(
custom_headers = custom_headers or {}
if reply_to:
- custom_headers['Reply-To'] = reply_to
+ custom_headers["Reply-To"] = reply_to
if return_path:
- custom_headers['Return-Path'] = return_path
+ custom_headers["Return-Path"] = return_path
message, recipients = build_mime_message(
mail_from=mail_from,
@@ -94,20 +95,5 @@ def send_email(
)
return ses_client.send_raw_email(
- Source=mail_from, Destinations=recipients, RawMessage={'Data': message.as_string()}
- )
-
-
-class SESHook(SesHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.ses.SesHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. Please use :class:`airflow.providers.amazon.aws.hooks.ses.SesHook`.",
- DeprecationWarning,
- stacklevel=2,
+ Source=mail_from, Destinations=recipients, RawMessage={"Data": message.as_string()}
)
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/sns.py b/airflow/providers/amazon/aws/hooks/sns.py
index fc009d9f9bdae..7b9fcea3dae9c 100644
--- a/airflow/providers/amazon/aws/hooks/sns.py
+++ b/airflow/providers/amazon/aws/hooks/sns.py
@@ -15,26 +15,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains AWS SNS hook"""
+from __future__ import annotations
+
import json
-import warnings
-from typing import Dict, Optional, Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
def _get_message_attribute(o):
if isinstance(o, bytes):
- return {'DataType': 'Binary', 'BinaryValue': o}
+ return {"DataType": "Binary", "BinaryValue": o}
if isinstance(o, str):
- return {'DataType': 'String', 'StringValue': o}
+ return {"DataType": "String", "StringValue": o}
if isinstance(o, (int, float)):
- return {'DataType': 'Number', 'StringValue': str(o)}
- if hasattr(o, '__iter__'):
- return {'DataType': 'String.Array', 'StringValue': json.dumps(o)}
+ return {"DataType": "Number", "StringValue": str(o)}
+ if hasattr(o, "__iter__"):
+ return {"DataType": "String.Array", "StringValue": json.dumps(o)}
raise TypeError(
- f'Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got {type(o)}'
+ f"Values in MessageAttributes must be one of bytes, str, int, float, or iterable; got {type(o)}"
)
@@ -50,14 +49,14 @@ class SnsHook(AwsBaseHook):
"""
def __init__(self, *args, **kwargs):
- super().__init__(client_type='sns', *args, **kwargs)
+ super().__init__(client_type="sns", *args, **kwargs)
def publish_to_target(
self,
target_arn: str,
message: str,
- subject: Optional[str] = None,
- message_attributes: Optional[dict] = None,
+ subject: str | None = None,
+ message_attributes: dict | None = None,
):
"""
Publish a message to a topic or an endpoint.
@@ -75,33 +74,18 @@ def publish_to_target(
- iterable = String.Array
"""
- publish_kwargs: Dict[str, Union[str, dict]] = {
- 'TargetArn': target_arn,
- 'MessageStructure': 'json',
- 'Message': json.dumps({'default': message}),
+ publish_kwargs: dict[str, str | dict] = {
+ "TargetArn": target_arn,
+ "MessageStructure": "json",
+ "Message": json.dumps({"default": message}),
}
# Construct args this way because boto3 distinguishes from missing args and those set to None
if subject:
- publish_kwargs['Subject'] = subject
+ publish_kwargs["Subject"] = subject
if message_attributes:
- publish_kwargs['MessageAttributes'] = {
+ publish_kwargs["MessageAttributes"] = {
key: _get_message_attribute(val) for key, val in message_attributes.items()
}
return self.get_conn().publish(**publish_kwargs)
-
-
-class AwsSnsHook(SnsHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.sns.SnsHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. Please use :class:`airflow.providers.amazon.aws.hooks.sns.SnsHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/hooks/sqs.py b/airflow/providers/amazon/aws/hooks/sqs.py
index b94756f63aa26..2d3fd9de20e0f 100644
--- a/airflow/providers/amazon/aws/hooks/sqs.py
+++ b/airflow/providers/amazon/aws/hooks/sqs.py
@@ -15,10 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains AWS SQS hook"""
-import warnings
-from typing import Dict, Optional
+from __future__ import annotations
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -38,7 +36,7 @@ def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "sqs"
super().__init__(*args, **kwargs)
- def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Dict:
+ def create_queue(self, queue_name: str, attributes: dict | None = None) -> dict:
"""
Create queue using connection object
@@ -48,7 +46,6 @@ def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Di
:return: dict with the information about the queue
For details of the returned value see :py:meth:`SQS.create_queue`
- :rtype: dict
"""
return self.get_conn().create_queue(QueueName=queue_name, Attributes=attributes or {})
@@ -57,8 +54,9 @@ def send_message(
queue_url: str,
message_body: str,
delay_seconds: int = 0,
- message_attributes: Optional[Dict] = None,
- ) -> Dict:
+ message_attributes: dict | None = None,
+ message_group_id: str | None = None,
+ ) -> dict:
"""
Send message to the queue
@@ -67,29 +65,19 @@ def send_message(
:param delay_seconds: seconds to delay the message
:param message_attributes: additional attributes for the message (default: None)
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
+ :param message_group_id: This applies only to FIFO (first-in-first-out) queues. (default: None)
+ For details of this parameter see :py:meth:`botocore.client.SQS.send_message`
:return: dict with the information about the message sent
For details of the returned value see :py:meth:`botocore.client.SQS.send_message`
- :rtype: dict
"""
- return self.get_conn().send_message(
- QueueUrl=queue_url,
- MessageBody=message_body,
- DelaySeconds=delay_seconds,
- MessageAttributes=message_attributes or {},
- )
-
-
-class SQSHook(SqsHook):
- """
- This hook is deprecated.
- Please use :class:`airflow.providers.amazon.aws.hooks.sqs.SqsHook`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This hook is deprecated. Please use :class:`airflow.providers.amazon.aws.hooks.sqs.SqsHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
+ params = {
+ "QueueUrl": queue_url,
+ "MessageBody": message_body,
+ "DelaySeconds": delay_seconds,
+ "MessageAttributes": message_attributes or {},
+ }
+ if message_group_id:
+ params["MessageGroupId"] = message_group_id
+
+ return self.get_conn().send_message(**params)
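The reworked ``send_message`` above builds its keyword arguments as a dict so that ``MessageGroupId`` is sent only when a group id is given. A minimal sketch of that pattern follows. ``build_send_message_params`` is a hypothetical helper written for illustration and does not exist in the hook; the queue URL in the usage note is a placeholder.

```python
from __future__ import annotations


def build_send_message_params(
    queue_url: str,
    message_body: str,
    delay_seconds: int = 0,
    message_attributes: dict | None = None,
    message_group_id: str | None = None,
) -> dict:
    """Assemble the keyword arguments, omitting MessageGroupId when unset."""
    params = {
        "QueueUrl": queue_url,
        "MessageBody": message_body,
        "DelaySeconds": delay_seconds,
        "MessageAttributes": message_attributes or {},
    }
    if message_group_id:
        # Only added when the caller supplies a group id.
        params["MessageGroupId"] = message_group_id
    return params
```

For example, ``build_send_message_params("https://example.invalid/q.fifo", "hi", message_group_id="g1")`` includes the ``MessageGroupId`` key, while the same call without ``message_group_id`` leaves it out entirely instead of passing ``None``.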
diff --git a/airflow/providers/amazon/aws/hooks/step_function.py b/airflow/providers/amazon/aws/hooks/step_function.py
index 97ffb10c04147..1819be12b7602 100644
--- a/airflow/providers/amazon/aws/hooks/step_function.py
+++ b/airflow/providers/amazon/aws/hooks/step_function.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
-from typing import Optional, Union
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -32,15 +32,15 @@ class StepFunctionHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""
- def __init__(self, region_name: Optional[str] = None, *args, **kwargs) -> None:
+ def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "stepfunctions"
super().__init__(*args, **kwargs)
def start_execution(
self,
state_machine_arn: str,
- name: Optional[str] = None,
- state_machine_input: Union[dict, str, None] = None,
+ name: str | None = None,
+ state_machine_input: dict | str | None = None,
) -> str:
"""
Start Execution of the State Machine.
@@ -50,21 +50,20 @@ def start_execution(
:param name: The name of the execution.
:param state_machine_input: JSON data input to pass to the State Machine
:return: Execution ARN
- :rtype: str
"""
- execution_args = {'stateMachineArn': state_machine_arn}
+ execution_args = {"stateMachineArn": state_machine_arn}
if name is not None:
- execution_args['name'] = name
+ execution_args["name"] = name
if state_machine_input is not None:
if isinstance(state_machine_input, str):
- execution_args['input'] = state_machine_input
+ execution_args["input"] = state_machine_input
elif isinstance(state_machine_input, dict):
- execution_args['input'] = json.dumps(state_machine_input)
+ execution_args["input"] = json.dumps(state_machine_input)
- self.log.info('Executing Step Function State Machine: %s', state_machine_arn)
+ self.log.info("Executing Step Function State Machine: %s", state_machine_arn)
response = self.conn.start_execution(**execution_args)
- return response.get('executionArn')
+ return response.get("executionArn")
def describe_execution(self, execution_arn: str) -> dict:
"""
@@ -73,6 +72,5 @@ def describe_execution(self, execution_arn: str) -> dict:
:param execution_arn: ARN of the State Machine Execution
:return: Dict with Execution details
- :rtype: dict
"""
return self.get_conn().describe_execution(executionArn=execution_arn)
diff --git a/airflow/providers/amazon/aws/hooks/sts.py b/airflow/providers/amazon/aws/hooks/sts.py
index aff787ee5d70e..8323ed2163f63 100644
--- a/airflow/providers/amazon/aws/hooks/sts.py
+++ b/airflow/providers/amazon/aws/hooks/sts.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -33,9 +34,8 @@ def __init__(self, *args, **kwargs):
def get_account_number(self) -> str:
"""Get the account Number"""
-
try:
- return self.get_conn().get_caller_identity()['Account']
+ return self.get_conn().get_caller_identity()["Account"]
except Exception as general_error:
self.log.error("Failed to get the AWS Account Number, error: %s", general_error)
raise
diff --git a/airflow/providers/apache/cassandra/example_dags/__init__.py b/airflow/providers/amazon/aws/links/__init__.py
similarity index 100%
rename from airflow/providers/apache/cassandra/example_dags/__init__.py
rename to airflow/providers/amazon/aws/links/__init__.py
diff --git a/airflow/providers/amazon/aws/links/base_aws.py b/airflow/providers/amazon/aws/links/base_aws.py
new file mode 100644
index 0000000000000..fba2f17e96f39
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/base_aws.py
@@ -0,0 +1,95 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+from airflow.models import BaseOperatorLink, XCom
+
+if TYPE_CHECKING:
+ from airflow.models import BaseOperator
+ from airflow.models.taskinstance import TaskInstanceKey
+ from airflow.utils.context import Context
+
+
+BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
+
+
+class BaseAwsLink(BaseOperatorLink):
+ """Base Helper class for constructing AWS Console Link"""
+
+ name: ClassVar[str]
+ key: ClassVar[str]
+ format_str: ClassVar[str]
+
+ @staticmethod
+ def get_aws_domain(aws_partition) -> str | None:
+ if aws_partition == "aws":
+ return "aws.amazon.com"
+ elif aws_partition == "aws-cn":
+ return "amazonaws.cn"
+ elif aws_partition == "aws-us-gov":
+ return "amazonaws-us-gov.com"
+
+ return None
+
+ def format_link(self, **kwargs) -> str:
+ """
+ Format AWS Service Link
+
+ Some AWS service links require additional escaping;
+ in that case this method should be overridden.
+ """
+ try:
+ return self.format_str.format(**kwargs)
+ except KeyError:
+ return ""
+
+ def get_link(
+ self,
+ operator: BaseOperator,
+ *,
+ ti_key: TaskInstanceKey,
+ ) -> str:
+ """
+ Link to Amazon Web Services Console.
+
+ :param operator: airflow operator
+ :param ti_key: TaskInstance ID to return link for
+ :return: link to external system
+ """
+ conf = XCom.get_value(key=self.key, ti_key=ti_key)
+ return self.format_link(**conf) if conf else ""
+
+ @classmethod
+ def persist(
+ cls, context: Context, operator: BaseOperator, region_name: str, aws_partition: str, **kwargs
+ ) -> None:
+ """Store link information into XCom"""
+ if not operator.do_xcom_push:
+ return
+
+ operator.xcom_push(
+ context,
+ key=cls.key,
+ value={
+ "region_name": region_name,
+ "aws_domain": cls.get_aws_domain(aws_partition),
+ **kwargs,
+ },
+ )
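A minimal sketch of how a `format_str` built on `BASE_AWS_CONSOLE_LINK` resolves to a console URL, mirroring the `format_link` logic in this file without importing Airflow. The `render_link` helper and all sample values (region, job ID) are hypothetical and used only for illustration.

```python
BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"

# Template of the same shape as the link subclasses define (values are filled from XCom).
FORMAT_STR = BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#jobs/detail/{job_id}"


def render_link(format_str: str, **kwargs) -> str:
    """Hypothetical stand-in for BaseAwsLink.format_link: empty string on a missing key."""
    try:
        return format_str.format(**kwargs)
    except KeyError:
        return ""


# All placeholders supplied: a full console URL is produced.
link = render_link(
    FORMAT_STR,
    aws_domain="aws.amazon.com",
    region_name="eu-west-1",
    job_id="example-job",
)

# A placeholder is missing (no aws_domain, no job_id): the helper falls back to "".
missing = render_link(FORMAT_STR, region_name="eu-west-1")
```

Returning an empty string rather than raising appears to be a deliberate choice here: a link that cannot be built is simply not rendered, instead of failing the caller.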
diff --git a/airflow/providers/amazon/aws/links/batch.py b/airflow/providers/amazon/aws/links/batch.py
new file mode 100644
index 0000000000000..432d129a7c328
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/batch.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class BatchJobDefinitionLink(BaseAwsLink):
+ """Helper class for constructing AWS Batch Job Definition Link"""
+
+ name = "Batch Job Definition"
+ key = "batch_job_definition"
+ format_str = (
+ BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#job-definition/detail/{job_definition_arn}"
+ )
+
+
+class BatchJobDetailsLink(BaseAwsLink):
+ """Helper class for constructing AWS Batch Job Details Link"""
+
+ name = "Batch Job Details"
+ key = "batch_job_details"
+ format_str = BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#jobs/detail/{job_id}"
+
+
+class BatchJobQueueLink(BaseAwsLink):
+ """Helper class for constructing AWS Batch Job Queue Link"""
+
+ name = "Batch Job Queue"
+ key = "batch_job_queue"
+ format_str = BASE_AWS_CONSOLE_LINK + "/batch/home?region={region_name}#queues/detail/{job_queue_arn}"
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
new file mode 100644
index 0000000000000..aa739567fb919
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -0,0 +1,29 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class EmrClusterLink(BaseAwsLink):
+ """Helper class for constructing AWS EMR Cluster Link"""
+
+ name = "EMR Cluster"
+ key = "emr_cluster"
+ format_str = (
+ BASE_AWS_CONSOLE_LINK + "/elasticmapreduce/home?region={region_name}#cluster-details:{job_flow_id}"
+ )
diff --git a/airflow/providers/amazon/aws/links/logs.py b/airflow/providers/amazon/aws/links/logs.py
new file mode 100644
index 0000000000000..7998191d9226d
--- /dev/null
+++ b/airflow/providers/amazon/aws/links/logs.py
@@ -0,0 +1,42 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from urllib.parse import quote_plus
+
+from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink
+
+
+class CloudWatchEventsLink(BaseAwsLink):
+ """Helper class for constructing AWS CloudWatch Events Link"""
+
+ name = "CloudWatch Events"
+ key = "cloudwatch_events"
+ format_str = (
+ BASE_AWS_CONSOLE_LINK
+ + "/cloudwatch/home?region={awslogs_region}#logsV2:log-groups/log-group/{awslogs_group}"
+ + "/log-events/{awslogs_stream_name}"
+ )
+
+ def format_link(self, **kwargs) -> str:
+ for field in ("awslogs_stream_name", "awslogs_group"):
+ if field in kwargs:
+ kwargs[field] = quote_plus(kwargs[field])
+ else:
+ return ""
+
+ return super().format_link(**kwargs)
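A short sketch of why `CloudWatchEventsLink` escapes the group and stream names: both commonly contain `/`, which would otherwise be read as part of the URL fragment path. The `render_cloudwatch_link` function and the sample names below are hypothetical; only `urllib.parse.quote_plus` and the template text come from the change above.

```python
from urllib.parse import quote_plus

BASE_AWS_CONSOLE_LINK = "https://console.{aws_domain}"
FORMAT_STR = (
    BASE_AWS_CONSOLE_LINK
    + "/cloudwatch/home?region={awslogs_region}#logsV2:log-groups/log-group/{awslogs_group}"
    + "/log-events/{awslogs_stream_name}"
)


def render_cloudwatch_link(**kwargs) -> str:
    """Hypothetical standalone version of the escaping done in format_link."""
    for field in ("awslogs_stream_name", "awslogs_group"):
        if field not in kwargs:
            return ""  # a required field is absent: no link
        kwargs[field] = quote_plus(kwargs[field])  # "/" becomes "%2F"
    try:
        return FORMAT_STR.format(**kwargs)
    except KeyError:
        return ""


link = render_cloudwatch_link(
    aws_domain="aws.amazon.com",
    awslogs_region="us-east-1",
    awslogs_group="/ecs/demo",
    awslogs_stream_name="task/demo/1",
)
no_stream = render_cloudwatch_link(
    aws_domain="aws.amazon.com", awslogs_region="us-east-1", awslogs_group="/ecs/demo"
)
```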
diff --git a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
index c975a2cb83fc6..e50a6d4d74e2e 100644
--- a/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
+++ b/airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -15,17 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
+from __future__ import annotations
+
from datetime import datetime
import watchtower
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
+from airflow.compat.functools import cached_property
from airflow.configuration import conf
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -42,9 +40,9 @@ class CloudwatchTaskHandler(FileTaskHandler, LoggingMixin):
:param filename_template: template for file name (local storage) or log stream name (remote)
"""
- def __init__(self, base_log_folder: str, log_group_arn: str, filename_template: str):
+ def __init__(self, base_log_folder: str, log_group_arn: str, filename_template: str | None = None):
super().__init__(base_log_folder, filename_template)
- split_arn = log_group_arn.split(':')
+ split_arn = log_group_arn.split(":")
self.handler = None
self.log_group = split_arn[6]
@@ -54,30 +52,19 @@ def __init__(self, base_log_folder: str, log_group_arn: str, filename_template:
@cached_property
def hook(self):
"""Returns AwsLogsHook."""
- remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
- try:
- from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
-
- return AwsLogsHook(aws_conn_id=remote_conn_id, region_name=self.region_name)
- except Exception as e:
- self.log.error(
- 'Could not create an AwsLogsHook with connection id "%s". '
- 'Please make sure that apache-airflow[aws] is installed and '
- 'the Cloudwatch logs connection exists. Exception: "%s"',
- remote_conn_id,
- e,
- )
- return None
+ return AwsLogsHook(
+ aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), region_name=self.region_name
+ )
def _render_filename(self, ti, try_number):
# Replace unsupported log group name characters
- return super()._render_filename(ti, try_number).replace(':', '_')
+ return super()._render_filename(ti, try_number).replace(":", "_")
def set_context(self, ti):
super().set_context(ti)
self.handler = watchtower.CloudWatchLogHandler(
- log_group=self.log_group,
- stream_name=self._render_filename(ti, ti.try_number),
+ log_group_name=self.log_group,
+ log_stream_name=self._render_filename(ti, ti.try_number),
boto3_client=self.hook.get_conn(),
)
@@ -97,11 +84,21 @@ def close(self):
def _read(self, task_instance, try_number, metadata=None):
stream_name = self._render_filename(task_instance, try_number)
- return (
- f'*** Reading remote log from Cloudwatch log_group: {self.log_group} '
- f'log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n',
- {'end_of_log': True},
- )
+ try:
+ return (
+ f"*** Reading remote log from Cloudwatch log_group: {self.log_group} "
+ f"log_stream: {stream_name}.\n{self.get_cloudwatch_logs(stream_name=stream_name)}\n",
+ {"end_of_log": True},
+ )
+ except Exception as e:
+ log = (
+ f"*** Unable to read remote logs from Cloudwatch (log_group: {self.log_group}, log_stream: "
+ f"{stream_name})\n*** {str(e)}\n\n"
+ )
+ self.log.error(log)
+ local_log, metadata = super()._read(task_instance, try_number, metadata)
+ log += local_log
+ return log, metadata
def get_cloudwatch_logs(self, stream_name: str) -> str:
"""
@@ -110,21 +107,15 @@ def get_cloudwatch_logs(self, stream_name: str) -> str:
:param stream_name: name of the Cloudwatch log stream to get all logs from
:return: string of all logs from the given log stream
"""
- try:
- events = list(
- self.hook.get_log_events(
- log_group=self.log_group, log_stream_name=stream_name, start_from_head=True
- )
- )
-
- return '\n'.join(self._event_to_str(event) for event in events)
- except Exception:
- msg = f'Could not read remote logs from log_group: {self.log_group} log_stream: {stream_name}.'
- self.log.exception(msg)
- return msg
+ events = self.hook.get_log_events(
+ log_group=self.log_group,
+ log_stream_name=stream_name,
+ start_from_head=True,
+ )
+ return "\n".join(self._event_to_str(event) for event in events)
def _event_to_str(self, event: dict) -> str:
- event_dt = datetime.utcfromtimestamp(event['timestamp'] / 1000.0)
- formatted_event_dt = event_dt.strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
- message = event['message']
- return f'[{formatted_event_dt}] {message}'
+ event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
+ formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
+ message = event["message"]
+ return f"[{formatted_event_dt}] {message}"
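As an illustration of the `_event_to_str` formatting above, the sketch below converts a CloudWatch-style event (timestamp in epoch milliseconds) into a log line. It swaps `datetime.utcfromtimestamp` for the timezone-aware `datetime.fromtimestamp(..., tz=timezone.utc)`, since the former is deprecated in recent Python versions; that substitution, the `event_to_str` name, and the sample event are assumptions, not part of the change.

```python
from datetime import datetime, timezone


def event_to_str(event: dict) -> str:
    """Format a CloudWatch-style log event as '[YYYY-MM-DD HH:MM:SS,mmm] message'."""
    # CloudWatch timestamps are in milliseconds since the epoch.
    event_dt = datetime.fromtimestamp(event["timestamp"] / 1000.0, tz=timezone.utc)
    # %f yields microseconds (6 digits); dropping the last 3 leaves milliseconds.
    formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
    return f"[{formatted_event_dt}] {event['message']}"


line = event_to_str({"timestamp": 1600000000123, "message": "task started"})
```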
diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py
index ce3da88f16916..8535277b13e70 100644
--- a/airflow/providers/amazon/aws/log/s3_task_handler.py
+++ b/airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -15,16 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import os
import pathlib
-import sys
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from airflow.compat.functools import cached_property
from airflow.configuration import conf
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.utils.log.file_task_handler import FileTaskHandler
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -36,10 +34,10 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
uploads to and reads from S3 remote storage.
"""
- def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: str):
+ def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template: str | None = None):
super().__init__(base_log_folder, filename_template)
self.remote_base = s3_log_folder
- self.log_relative_path = ''
+ self.log_relative_path = ""
self._hook = None
self.closed = False
self.upload_on_close = True
@@ -47,20 +45,9 @@ def __init__(self, base_log_folder: str, s3_log_folder: str, filename_template:
@cached_property
def hook(self):
"""Returns S3Hook."""
- remote_conn_id = conf.get('logging', 'REMOTE_LOG_CONN_ID')
- try:
- from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-
- return S3Hook(remote_conn_id, transfer_config_args={"use_threads": False})
- except Exception as e:
- self.log.exception(
- 'Could not create an S3Hook with connection id "%s". '
- 'Please make sure that apache-airflow[aws] is installed and '
- 'the S3 connection exists. Exception : "%s"',
- remote_conn_id,
- e,
- )
- return None
+ return S3Hook(
+ aws_conn_id=conf.get("logging", "REMOTE_LOG_CONN_ID"), transfer_config_args={"use_threads": False}
+ )
def set_context(self, ti):
super().set_context(ti)
@@ -72,7 +59,7 @@ def set_context(self, ti):
# Clear the file first so that duplicate data is not uploaded
# when re-using the same path (e.g. with rescheduled sensors)
if self.upload_on_close:
- with open(self.handler.baseFilename, 'w'):
+ with open(self.handler.baseFilename, "w"):
pass
def close(self):
@@ -122,18 +109,18 @@ def _read(self, ti, try_number, metadata=None):
log_exists = self.s3_log_exists(remote_loc)
except Exception as error:
self.log.exception("Failed to verify remote log exists %s.", remote_loc)
- log = f'*** Failed to verify remote log exists {remote_loc}.\n{error}\n'
+ log = f"*** Failed to verify remote log exists {remote_loc}.\n{error}\n"
if log_exists:
# If S3 remote file exists, we do not fetch logs from task instance
# local machine even if there are errors reading remote logs, as
# returned remote_log will contain error messages.
remote_log = self.s3_read(remote_loc, return_error=True)
- log = f'*** Reading remote log from {remote_loc}.\n{remote_log}\n'
- return log, {'end_of_log': True}
+ log = f"*** Reading remote log from {remote_loc}.\n{remote_log}\n"
+ return log, {"end_of_log": True}
else:
- log += '*** Falling back to local log\n'
- local_log, metadata = super()._read(ti, try_number)
+ log += "*** Falling back to local log\n"
+ local_log, metadata = super()._read(ti, try_number, metadata)
return log + local_log, metadata
def s3_log_exists(self, remote_log_location: str) -> bool:
@@ -158,12 +145,12 @@ def s3_read(self, remote_log_location: str, return_error: bool = False) -> str:
try:
return self.hook.read_key(remote_log_location)
except Exception as error:
- msg = f'Could not read logs from {remote_log_location} with error: {error}'
+ msg = f"Could not read logs from {remote_log_location} with error: {error}"
self.log.exception(msg)
# return error if needed
if return_error:
return msg
- return ''
+ return ""
def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_retry: int = 1):
"""
@@ -179,9 +166,9 @@ def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_
try:
if append and self.s3_log_exists(remote_log_location):
old_log = self.s3_read(remote_log_location)
- log = '\n'.join([old_log, log]) if old_log else log
+ log = "\n".join([old_log, log]) if old_log else log
except Exception:
- self.log.exception('Could not verify previous log to append')
+ self.log.exception("Could not verify previous log to append")
# Default to a single retry attempt because s3 upload failures are
# rare but occasionally occur. Multiple retry attempts are unlikely
@@ -192,11 +179,11 @@ def s3_write(self, log: str, remote_log_location: str, append: bool = True, max_
log,
key=remote_log_location,
replace=True,
- encrypt=conf.getboolean('logging', 'ENCRYPT_S3_LOGS'),
+ encrypt=conf.getboolean("logging", "ENCRYPT_S3_LOGS"),
)
break
except Exception:
if try_num < max_retry:
- self.log.warning('Failed attempt to write logs to %s, will retry', remote_log_location)
+ self.log.warning("Failed attempt to write logs to %s, will retry", remote_log_location)
else:
- self.log.exception('Could not write logs to %s', remote_log_location)
+ self.log.exception("Could not write logs to %s", remote_log_location)
diff --git a/airflow/providers/amazon/aws/operators/appflow.py b/airflow/providers/amazon/aws/operators/appflow.py
new file mode 100644
index 0000000000000..b077b04bb59dd
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/appflow.py
@@ -0,0 +1,480 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, cast
+
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.operators.python import ShortCircuitOperator
+from airflow.providers.amazon.aws.hooks.appflow import AppflowHook
+from airflow.providers.amazon.aws.utils import datetime_to_epoch_ms
+
+if TYPE_CHECKING:
+ from mypy_boto3_appflow.type_defs import (
+ DescribeFlowExecutionRecordsResponseTypeDef,
+ ExecutionRecordTypeDef,
+ TaskTypeDef,
+ )
+
+ from airflow.utils.context import Context
+
+
+SUPPORTED_SOURCES = {"salesforce", "zendesk"}
+MANDATORY_FILTER_DATE_MSG = "The filter_date argument is mandatory for {entity}!"
+NOT_SUPPORTED_SOURCE_MSG = "Source {source} is not supported for {entity}!"
+
+
+class AppflowBaseOperator(BaseOperator):
+ """
+ Amazon Appflow Base Operator class (not supposed to be used directly in DAGs).
+
+ :param source: The source name (Supported: salesforce, zendesk)
+ :param flow_name: The flow name
+ :param flow_update: A boolean to enable/disable a flow update before the run
+ :param source_field: The field name to apply filters
+ :param filter_date: The date value (or template) to be used in filters.
+ :param poll_interval: how often in seconds to check the query status
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ ui_color = "#2bccbd"
+
+ def __init__(
+ self,
+ source: str,
+ flow_name: str,
+ flow_update: bool,
+ source_field: str | None = None,
+ filter_date: str | None = None,
+ poll_interval: int = 20,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ if source not in SUPPORTED_SOURCES:
+ raise ValueError(f"{source} is not a supported source (options: {SUPPORTED_SOURCES})!")
+ self.filter_date = filter_date
+ self.flow_name = flow_name
+ self.source = source
+ self.source_field = source_field
+ self.poll_interval = poll_interval
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+ self.flow_update = flow_update
+
+ @cached_property
+ def hook(self) -> AppflowHook:
+ """Create and return an AppflowHook."""
+ return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+
+ def execute(self, context: Context) -> None:
+ self.filter_date_parsed: datetime | None = (
+ datetime.fromisoformat(self.filter_date) if self.filter_date else None
+ )
+ self.connector_type = self._get_connector_type()
+ if self.flow_update:
+ self._update_flow()
+ self._run_flow(context)
+
+ def _get_connector_type(self) -> str:
+ response = self.hook.conn.describe_flow(flowName=self.flow_name)
+ connector_type = response["sourceFlowConfig"]["connectorType"]
+ if self.source != connector_type.lower():
+ raise ValueError(f"Incompatible source ({self.source}) and connector type ({connector_type})!")
+ return connector_type
+
+ def _update_flow(self) -> None:
+ self.hook.update_flow_filter(flow_name=self.flow_name, filter_tasks=[], set_trigger_ondemand=True)
+
+ def _run_flow(self, context) -> str:
+ execution_id = self.hook.run_flow(flow_name=self.flow_name, poll_interval=self.poll_interval)
+ task_instance = context["task_instance"]
+ task_instance.xcom_push("execution_id", execution_id)
+ return execution_id
+
+
+class AppflowRunOperator(AppflowBaseOperator):
+ """
+ Execute an Appflow run with filters as is.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AppflowRunOperator`
+
+ :param source: The source name (Supported: salesforce, zendesk)
+ :param flow_name: The flow name
+ :param poll_interval: how often in seconds to check the query status
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ def __init__(
+ self,
+ source: str,
+ flow_name: str,
+ poll_interval: int = 20,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ if source not in {"salesforce", "zendesk"}:
+ raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunOperator"))
+ super().__init__(
+ source=source,
+ flow_name=flow_name,
+ flow_update=False,
+ source_field=None,
+ filter_date=None,
+ poll_interval=poll_interval,
+ aws_conn_id=aws_conn_id,
+ region=region,
+ **kwargs,
+ )
+
+
+class AppflowRunFullOperator(AppflowBaseOperator):
+ """
+ Execute an Appflow full run, removing any filter.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AppflowRunFullOperator`
+
+ :param source: The source name (Supported: salesforce, zendesk)
+ :param flow_name: The flow name
+ :param poll_interval: how often in seconds to check the query status
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ def __init__(
+ self,
+ source: str,
+ flow_name: str,
+ poll_interval: int = 20,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ if source not in {"salesforce", "zendesk"}:
+ raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunFullOperator"))
+ super().__init__(
+ source=source,
+ flow_name=flow_name,
+ flow_update=True,
+ source_field=None,
+ filter_date=None,
+ poll_interval=poll_interval,
+ aws_conn_id=aws_conn_id,
+ region=region,
+ **kwargs,
+ )
+
+
+class AppflowRunBeforeOperator(AppflowBaseOperator):
+ """
+ Execute an Appflow run after updating the filters to select only previous data.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AppflowRunBeforeOperator`
+
+ :param source: The source name (Supported: salesforce)
+ :param flow_name: The flow name
+ :param source_field: The field name to apply filters
+ :param filter_date: The date value (or template) to be used in filters.
+ :param poll_interval: how often in seconds to check the query status
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ template_fields = ("filter_date",)
+
+ def __init__(
+ self,
+ source: str,
+ flow_name: str,
+ source_field: str,
+ filter_date: str,
+ poll_interval: int = 20,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ if not filter_date:
+ raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunBeforeOperator"))
+ if source != "salesforce":
+ raise ValueError(
+ NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunBeforeOperator")
+ )
+ super().__init__(
+ source=source,
+ flow_name=flow_name,
+ flow_update=True,
+ source_field=source_field,
+ filter_date=filter_date,
+ poll_interval=poll_interval,
+ aws_conn_id=aws_conn_id,
+ region=region,
+ **kwargs,
+ )
+
+ def _update_flow(self) -> None:
+ if not self.filter_date_parsed:
+ raise ValueError(f"Invalid filter_date argument parsed value: {self.filter_date_parsed}")
+ if not self.source_field:
+ raise ValueError(f"Invalid source_field argument value: {self.source_field}")
+ filter_task: TaskTypeDef = {
+ "taskType": "Filter",
+ "connectorOperator": {self.connector_type: "LESS_THAN"}, # type: ignore
+ "sourceFields": [self.source_field],
+ "taskProperties": {
+ "DATA_TYPE": "datetime",
+ "VALUE": str(datetime_to_epoch_ms(self.filter_date_parsed)),
+ }, # NOT inclusive
+ }
+ self.hook.update_flow_filter(
+ flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True
+ )
+
+
+class AppflowRunAfterOperator(AppflowBaseOperator):
+ """
+ Execute an Appflow run after updating the filters to select only future data.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AppflowRunAfterOperator`
+
+ :param source: The source name (Supported: salesforce, zendesk)
+ :param flow_name: The flow name
+ :param source_field: The field name to apply filters
+ :param filter_date: The date value (or template) to be used in filters.
+ :param poll_interval: how often in seconds to check the query status
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ template_fields = ("filter_date",)
+
+ def __init__(
+ self,
+ source: str,
+ flow_name: str,
+ source_field: str,
+ filter_date: str,
+ poll_interval: int = 20,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ if not filter_date:
+ raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunAfterOperator"))
+ if source not in {"salesforce", "zendesk"}:
+ raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunAfterOperator"))
+ super().__init__(
+ source=source,
+ flow_name=flow_name,
+ flow_update=True,
+ source_field=source_field,
+ filter_date=filter_date,
+ poll_interval=poll_interval,
+ aws_conn_id=aws_conn_id,
+ region=region,
+ **kwargs,
+ )
+
+ def _update_flow(self) -> None:
+ if not self.filter_date_parsed:
+ raise ValueError(f"Invalid filter_date argument parsed value: {self.filter_date_parsed}")
+ if not self.source_field:
+ raise ValueError(f"Invalid source_field argument value: {self.source_field}")
+ filter_task: TaskTypeDef = {
+ "taskType": "Filter",
+ "connectorOperator": {self.connector_type: "GREATER_THAN"}, # type: ignore
+ "sourceFields": [self.source_field],
+ "taskProperties": {
+ "DATA_TYPE": "datetime",
+ "VALUE": str(datetime_to_epoch_ms(self.filter_date_parsed)),
+ }, # NOT inclusive
+ }
+ self.hook.update_flow_filter(
+ flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True
+ )
+
+
+class AppflowRunDailyOperator(AppflowBaseOperator):
+ """
+ Execute an Appflow run after updating the filters to select only a single day.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AppflowRunDailyOperator`
+
+ :param source: The source name (Supported: salesforce)
+ :param flow_name: The flow name
+ :param source_field: The field name to apply filters
+ :param filter_date: The date value (or template) to be used in filters.
+ :param poll_interval: how often in seconds to check the query status
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ template_fields = ("filter_date",)
+
+ def __init__(
+ self,
+ source: str,
+ flow_name: str,
+ source_field: str,
+ filter_date: str,
+ poll_interval: int = 20,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ if not filter_date:
+ raise ValueError(MANDATORY_FILTER_DATE_MSG.format(entity="AppflowRunDailyOperator"))
+ if source != "salesforce":
+ raise ValueError(NOT_SUPPORTED_SOURCE_MSG.format(source=source, entity="AppflowRunDailyOperator"))
+ super().__init__(
+ source=source,
+ flow_name=flow_name,
+ flow_update=True,
+ source_field=source_field,
+ filter_date=filter_date,
+ poll_interval=poll_interval,
+ aws_conn_id=aws_conn_id,
+ region=region,
+ **kwargs,
+ )
+
+ def _update_flow(self) -> None:
+ if not self.filter_date_parsed:
+ raise ValueError(f"Invalid filter_date argument parsed value: {self.filter_date_parsed}")
+ if not self.source_field:
+ raise ValueError(f"Invalid source_field argument value: {self.source_field}")
+ start_filter_date = self.filter_date_parsed - timedelta(milliseconds=1)
+ end_filter_date = self.filter_date_parsed + timedelta(days=1)
+ filter_task: TaskTypeDef = {
+ "taskType": "Filter",
+ "connectorOperator": {self.connector_type: "BETWEEN"}, # type: ignore
+ "sourceFields": [self.source_field],
+ "taskProperties": {
+ "DATA_TYPE": "datetime",
+ "LOWER_BOUND": str(datetime_to_epoch_ms(start_filter_date)), # NOT inclusive
+ "UPPER_BOUND": str(datetime_to_epoch_ms(end_filter_date)), # NOT inclusive
+ },
+ }
+ self.hook.update_flow_filter(
+ flow_name=self.flow_name, filter_tasks=[filter_task], set_trigger_ondemand=True
+ )
+
+
+class AppflowRecordsShortCircuitOperator(ShortCircuitOperator):
+ """
+ Short-circuit in case of an empty AppFlow run.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AppflowRecordsShortCircuitOperator`
+
+ :param flow_name: The flow name
+ :param appflow_run_task_id: Run task ID from where this operator should extract the execution ID
+ :param ignore_downstream_trigger_rules: Ignore downstream trigger rules
+ :param aws_conn_id: aws connection to use
+ :param region: aws region to use
+ """
+
+ ui_color = "#33ffec" # Light blue
+
+ def __init__(
+ self,
+ *,
+ flow_name: str,
+ appflow_run_task_id: str,
+ ignore_downstream_trigger_rules: bool = True,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ python_callable=self._has_new_records_func,
+ op_kwargs={
+ "flow_name": flow_name,
+ "appflow_run_task_id": appflow_run_task_id,
+ },
+ ignore_downstream_trigger_rules=ignore_downstream_trigger_rules,
+ **kwargs,
+ )
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+
+ @staticmethod
+ def _get_target_execution_id(
+ records: list[ExecutionRecordTypeDef], execution_id: str
+ ) -> ExecutionRecordTypeDef | None:
+ for record in records:
+ if record.get("executionId") == execution_id:
+ return record
+ return None
+
+ @cached_property
+ def hook(self) -> AppflowHook:
+ """Create and return an AppflowHook."""
+ return AppflowHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+
+ def _has_new_records_func(self, **kwargs) -> bool:
+ appflow_task_id = kwargs["appflow_run_task_id"]
+ self.log.info("appflow_task_id: %s", appflow_task_id)
+ flow_name = kwargs["flow_name"]
+ self.log.info("flow_name: %s", flow_name)
+ af_client = self.hook.conn
+ task_instance = kwargs["task_instance"]
+ execution_id = task_instance.xcom_pull(task_ids=appflow_task_id, key="execution_id") # type: ignore
+ if not execution_id:
+ raise AirflowException(f"No execution_id found from task_id {appflow_task_id}!")
+ self.log.info("execution_id: %s", execution_id)
+ args = {"flowName": flow_name, "maxResults": 100}
+ response: DescribeFlowExecutionRecordsResponseTypeDef = cast(
+ "DescribeFlowExecutionRecordsResponseTypeDef", {}
+ )
+ record = None
+
+ while not record:
+ if "nextToken" in response:
+ response = af_client.describe_flow_execution_records(nextToken=response["nextToken"], **args)
+ else:
+ response = af_client.describe_flow_execution_records(**args)
+ record = AppflowRecordsShortCircuitOperator._get_target_execution_id(
+ response["flowExecutions"], execution_id
+ )
+ if not record and "nextToken" not in response:
+ raise AirflowException(f"Flow execution ({execution_id}) not found in the execution records.")
+
+ execution = record.get("executionResult", {})
+ if "recordsProcessed" not in execution:
+ raise AirflowException(f"Flow ({execution_id}) without recordsProcessed info!")
+ records_processed = execution["recordsProcessed"]
+ self.log.info("records_processed: %d", records_processed)
+ task_instance.xcom_push("records_processed", records_processed) # type: ignore
+ return records_processed > 0
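The record lookup in `_has_new_records_func` above follows `nextToken` from page to page until the matching execution appears. A minimal standalone sketch of that pattern is shown below. The names `find_execution_record` and `fake_fetch_page`, and the two-page sample data, are hypothetical and exist only for illustration; the real operator calls `describe_flow_execution_records` on the boto3 client.

```python
from typing import Any, Callable, Dict, Optional


def find_execution_record(
    fetch_page: Callable[..., Dict[str, Any]],
    execution_id: str,
) -> Optional[Dict[str, Any]]:
    """Walk paginated responses until the record with ``execution_id`` is found."""
    next_token: Optional[str] = None
    while True:
        # Pass nextToken only when a previous page provided one.
        kwargs = {"nextToken": next_token} if next_token else {}
        response = fetch_page(**kwargs)
        for record in response.get("flowExecutions", []):
            if record.get("executionId") == execution_id:
                return record
        next_token = response.get("nextToken")
        if not next_token:
            # All pages exhausted without a match.
            return None


# Hypothetical two-page fake standing in for the client call.
_PAGES: Dict[Optional[str], Dict[str, Any]] = {
    None: {"flowExecutions": [{"executionId": "a"}], "nextToken": "t1"},
    "t1": {
        "flowExecutions": [
            {"executionId": "b", "executionResult": {"recordsProcessed": 3}}
        ]
    },
}


def fake_fetch_page(nextToken: Optional[str] = None) -> Dict[str, Any]:
    return _PAGES[nextToken]
```

Returning `None` when the pages run out leaves it to the caller to decide whether a missing record is an error, which keeps the "not found" case separate from the "found but without `recordsProcessed`" case.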
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 6febe2a917532..61c897a4817d3 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -15,16 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-import sys
-import warnings
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
+from __future__ import annotations
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
@@ -49,12 +45,14 @@ class AthenaOperator(BaseOperator):
:param query_execution_context: Context in which query need to be run
:param result_configuration: Dict with path to store results in and config related to encryption
:param sleep_time: Time (in seconds) to wait between two consecutive calls to check query status on Athena
- :param max_tries: Number of times to poll for query state before function exits
+ :param max_tries: Deprecated - use max_polling_attempts instead.
+ :param max_polling_attempts: Number of times to poll for query state before function exits
+ To limit task execution time, use execution_timeout.
"""
- ui_color = '#44b5e2'
- template_fields: Sequence[str] = ('query', 'database', 'output_location')
- template_ext: Sequence[str] = ('.sql',)
+ ui_color = "#44b5e2"
+ template_fields: Sequence[str] = ("query", "database", "output_location")
+ template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}
def __init__(
@@ -64,12 +62,13 @@ def __init__(
database: str,
output_location: str,
aws_conn_id: str = "aws_default",
- client_request_token: Optional[str] = None,
+ client_request_token: str | None = None,
workgroup: str = "primary",
- query_execution_context: Optional[Dict[str, str]] = None,
- result_configuration: Optional[Dict[str, Any]] = None,
+ query_execution_context: dict[str, str] | None = None,
+ result_configuration: dict[str, Any] | None = None,
sleep_time: int = 30,
- max_tries: Optional[int] = None,
+ max_tries: int | None = None,
+ max_polling_attempts: int | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -82,18 +81,30 @@ def __init__(
self.query_execution_context = query_execution_context or {}
self.result_configuration = result_configuration or {}
self.sleep_time = sleep_time
- self.max_tries = max_tries
- self.query_execution_id = None # type: Optional[str]
+ self.max_polling_attempts = max_polling_attempts
+ self.query_execution_id: str | None = None
+
+ if max_tries:
+ warnings.warn(
+ f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+ "in a future release. Please use parameter `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ self.max_polling_attempts = max_tries
@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
- def execute(self, context: 'Context') -> Optional[str]:
+ def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena"""
- self.query_execution_context['Database'] = self.database
- self.result_configuration['OutputLocation'] = self.output_location
+ self.query_execution_context["Database"] = self.database
+ self.result_configuration["OutputLocation"] = self.output_location
self.query_execution_id = self.hook.run_query(
self.query,
self.query_execution_context,
@@ -101,18 +112,21 @@ def execute(self, context: 'Context') -> Optional[str]:
self.client_request_token,
self.workgroup,
)
- query_status = self.hook.poll_query_status(self.query_execution_id, self.max_tries)
+ query_status = self.hook.poll_query_status(
+ self.query_execution_id,
+ max_polling_attempts=self.max_polling_attempts,
+ )
if query_status in AthenaHook.FAILURE_STATES:
error_message = self.hook.get_state_change_reason(self.query_execution_id)
raise Exception(
- f'Final state of Athena job is {query_status}, query_execution_id is '
- f'{self.query_execution_id}. Error: {error_message}'
+ f"Final state of Athena job is {query_status}, query_execution_id is "
+ f"{self.query_execution_id}. Error: {error_message}"
)
elif not query_status or query_status in AthenaHook.INTERMEDIATE_STATES:
raise Exception(
- f'Final state of Athena job is {query_status}. Max tries of poll status exceeded, '
- f'query_execution_id is {self.query_execution_id}.'
+ f"Final state of Athena job is {query_status}. Max tries of poll status exceeded, "
+ f"query_execution_id is {self.query_execution_id}."
)
return self.query_execution_id
@@ -120,35 +134,19 @@ def execute(self, context: 'Context') -> Optional[str]:
def on_kill(self) -> None:
"""Cancel the submitted athena query"""
if self.query_execution_id:
- self.log.info('Received a kill signal.')
- self.log.info('Stopping Query with executionId - %s', self.query_execution_id)
+ self.log.info("Received a kill signal.")
+ self.log.info("Stopping Query with executionId - %s", self.query_execution_id)
response = self.hook.stop_query(self.query_execution_id)
http_status_code = None
try:
- http_status_code = response['ResponseMetadata']['HTTPStatusCode']
+ http_status_code = response["ResponseMetadata"]["HTTPStatusCode"]
except Exception as ex:
- self.log.error('Exception while cancelling query: %s', ex)
+ self.log.error("Exception while cancelling query: %s", ex)
finally:
if http_status_code is None or http_status_code != 200:
- self.log.error('Unable to request query cancel on athena. Exiting')
+ self.log.error("Unable to request query cancel on athena. Exiting")
else:
self.log.info(
- 'Polling Athena for query with id %s to reach final state', self.query_execution_id
+ "Polling Athena for query with id %s to reach final state", self.query_execution_id
)
self.hook.poll_query_status(self.query_execution_id)
-
-
-class AWSAthenaOperator(AthenaOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.athena.AthenaOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. Please use "
- "`airflow.providers.amazon.aws.operators.athena.AthenaOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
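The `max_tries` handling added to `AthenaOperator.__init__` above can be read as a small resolution rule: the deprecated argument still works, emits a warning, and must not contradict the new one. The sketch below restates that rule as a standalone function. The name `resolve_polling_attempts` is hypothetical, and it raises `ValueError` where the patch raises a bare `Exception`.

```python
import warnings
from typing import Optional


def resolve_polling_attempts(
    max_tries: Optional[int] = None,
    max_polling_attempts: Optional[int] = None,
) -> Optional[int]:
    """Return the effective polling limit, honouring the deprecated ``max_tries``."""
    if max_tries:
        warnings.warn(
            "`max_tries` is deprecated; use `max_polling_attempts` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Both arguments may be given only if they agree.
        if max_polling_attempts and max_polling_attempts != max_tries:
            raise ValueError("max_polling_attempts must be the same value as max_tries")
        return max_tries
    return max_polling_attempts
```

Under this rule `resolve_polling_attempts(max_tries=5)` yields 5 with a deprecation warning, and passing neither argument yields `None`, which the patch forwards to `poll_query_status` unchanged.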
diff --git a/airflow/providers/amazon/aws/operators/aws_lambda.py b/airflow/providers/amazon/aws/operators/aws_lambda.py
index c2d9d022fbcbc..70d96c094d2a8 100644
--- a/airflow/providers/amazon/aws/operators/aws_lambda.py
+++ b/airflow/providers/amazon/aws/operators/aws_lambda.py
@@ -15,88 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.lambda_function`."""
+from __future__ import annotations
-import json
-from typing import TYPE_CHECKING, Optional, Sequence
+import warnings
-from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+from airflow.providers.amazon.aws.operators.lambda_function import AwsLambdaInvokeFunctionOperator # noqa
-if TYPE_CHECKING:
- from airflow.utils.context import Context
-
-
-class AwsLambdaInvokeFunctionOperator(BaseOperator):
- """
- Invokes an AWS Lambda function.
- You can invoke a function synchronously (and wait for the response),
- or asynchronously.
- To invoke a function asynchronously,
- set `invocation_type` to `Event`. For more details,
- review the boto3 Lambda invoke docs.
-
- :param function_name: The name of the AWS Lambda function, version, or alias.
- :param payload: The JSON string that you want to provide to your Lambda function as input.
- :param log_type: Set to Tail to include the execution log in the response. Otherwise, set to "None".
- :param qualifier: Specify a version or alias to invoke a published version of the function.
- :param aws_conn_id: The AWS connection ID to use
-
- .. seealso::
- For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:AwsLambdaInvokeFunctionOperator`
-
- """
-
- template_fields: Sequence[str] = ('function_name', 'payload', 'qualifier', 'invocation_type')
- ui_color = '#ff7300'
-
- def __init__(
- self,
- *,
- function_name: str,
- log_type: Optional[str] = None,
- qualifier: Optional[str] = None,
- invocation_type: Optional[str] = None,
- client_context: Optional[str] = None,
- payload: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.function_name = function_name
- self.payload = payload
- self.log_type = log_type
- self.qualifier = qualifier
- self.invocation_type = invocation_type
- self.client_context = client_context
- self.aws_conn_id = aws_conn_id
-
- def execute(self, context: 'Context'):
- """
- Invokes the target AWS Lambda function from Airflow.
-
- :return: The response payload from the function, or an error object.
- """
- hook = LambdaHook(aws_conn_id=self.aws_conn_id)
- success_status_codes = [200, 202, 204]
- self.log.info("Invoking AWS Lambda function: %s with payload: %s", self.function_name, self.payload)
- response = hook.invoke_lambda(
- function_name=self.function_name,
- invocation_type=self.invocation_type,
- log_type=self.log_type,
- client_context=self.client_context,
- payload=self.payload,
- qualifier=self.qualifier,
- )
- self.log.info("Lambda response metadata: %r", response.get("ResponseMetadata"))
- if response.get("StatusCode") not in success_status_codes:
- raise ValueError('Lambda function did not execute', json.dumps(response.get("ResponseMetadata")))
- payload_stream = response.get("Payload")
- payload = payload_stream.read().decode()
- if "FunctionError" in response:
- raise ValueError(
- 'Lambda function execution resulted in error',
- {"ResponseMetadata": response.get("ResponseMetadata"), "Payload": payload},
- )
- self.log.info('Lambda function invocation succeeded: %r', response.get("ResponseMetadata"))
- return payload
+warnings.warn(
+ "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.lambda_function`.",
+ DeprecationWarning,
+ stacklevel=2,
+)
diff --git a/airflow/providers/amazon/aws/operators/batch.py b/airflow/providers/amazon/aws/operators/batch.py
index fcd925f9ae32e..9e85afaf4c0ce 100644
--- a/airflow/providers/amazon/aws/operators/batch.py
+++ b/airflow/providers/amazon/aws/operators/batch.py
@@ -14,23 +14,36 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-
"""
An Airflow operator for AWS Batch services
.. seealso::
- - http://boto3.readthedocs.io/en/latest/guide/configuration.html
- - http://boto3.readthedocs.io/en/latest/reference/services/batch.html
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html
+ - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
-import warnings
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from __future__ import annotations
+
+import sys
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.providers.amazon.aws.utils import trim_none_values
+
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:
+ from cached_property import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
+from airflow.providers.amazon.aws.links.batch import (
+ BatchJobDefinitionLink,
+ BatchJobDetailsLink,
+ BatchJobQueueLink,
+)
+from airflow.providers.amazon.aws.links.logs import CloudWatchEventsLink
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -87,14 +100,32 @@ class BatchOperator(BaseOperator):
"""
ui_color = "#c3dae0"
- arn = None # type: Optional[str]
+ arn: str | None = None
template_fields: Sequence[str] = (
+ "job_id",
"job_name",
+ "job_definition",
+ "job_queue",
"overrides",
+ "array_properties",
"parameters",
+ "waiters",
+ "tags",
+ "wait_for_completion",
)
template_fields_renderers = {"overrides": "json", "parameters": "json"}
+ @property
+ def operator_extra_links(self):
+ op_extra_links = [BatchJobDetailsLink()]
+ if self.wait_for_completion:
+ op_extra_links.extend([BatchJobDefinitionLink(), BatchJobQueueLink()])
+ if not self.array_properties:
+ # There is no CloudWatch Link to the parent Batch Job available.
+ op_extra_links.append(CloudWatchEventsLink())
+
+ return tuple(op_extra_links)
+
def __init__(
self,
*,
@@ -102,15 +133,16 @@ def __init__(
job_definition: str,
job_queue: str,
overrides: dict,
- array_properties: Optional[dict] = None,
- parameters: Optional[dict] = None,
- job_id: Optional[str] = None,
- waiters: Optional[Any] = None,
- max_retries: Optional[int] = None,
- status_retries: Optional[int] = None,
- aws_conn_id: Optional[str] = None,
- region_name: Optional[str] = None,
- tags: Optional[dict] = None,
+ array_properties: dict | None = None,
+ parameters: dict | None = None,
+ job_id: str | None = None,
+ waiters: Any | None = None,
+ max_retries: int | None = None,
+ status_retries: int | None = None,
+ aws_conn_id: str | None = None,
+ region_name: str | None = None,
+ tags: dict | None = None,
+ wait_for_completion: bool = True,
**kwargs,
):
@@ -124,6 +156,7 @@ def __init__(
self.parameters = parameters or {}
self.waiters = waiters
self.tags = tags or {}
+ self.wait_for_completion = wait_for_completion
self.hook = BatchClientHook(
max_retries=max_retries,
status_retries=status_retries,
@@ -131,20 +164,24 @@ def __init__(
region_name=region_name,
)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Submit and monitor an AWS Batch job
:raises: AirflowException
"""
self.submit_job(context)
- self.monitor_job(context)
+
+ if self.wait_for_completion:
+ self.monitor_job(context)
+
+ return self.job_id
def on_kill(self):
response = self.hook.client.terminate_job(jobId=self.job_id, reason="Task killed by the user")
self.log.info("AWS Batch job (%s) terminated: %s", self.job_id, response)
- def submit_job(self, context: 'Context'):
+ def submit_job(self, context: Context):
"""
Submit an AWS Batch job
@@ -167,14 +204,25 @@ def submit_job(self, context: 'Context'):
containerOverrides=self.overrides,
tags=self.tags,
)
- self.job_id = response["jobId"]
-
- self.log.info("AWS Batch job (%s) started: %s", self.job_id, response)
except Exception as e:
- self.log.error("AWS Batch job (%s) failed submission", self.job_id)
+ self.log.error(
+ "AWS Batch job failed submission - job definition: %s - on queue %s",
+ self.job_definition,
+ self.job_queue,
+ )
raise AirflowException(e)
- def monitor_job(self, context: 'Context'):
+ self.job_id = response["jobId"]
+ self.log.info("AWS Batch job (%s) started: %s", self.job_id, response)
+ BatchJobDetailsLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
+ job_id=self.job_id,
+ )
+
+ def monitor_job(self, context: Context):
"""
Monitor an AWS Batch job
monitor_job can raise an exception or an AirflowTaskTimeout can be raised if execution_timeout
@@ -184,28 +232,152 @@ def monitor_job(self, context: 'Context'):
:raises: AirflowException
"""
if not self.job_id:
- raise AirflowException('AWS Batch job - job_id was not found')
+ raise AirflowException("AWS Batch job - job_id was not found")
+
+ try:
+ job_desc = self.hook.get_job_description(self.job_id)
+ job_definition_arn = job_desc["jobDefinition"]
+ job_queue_arn = job_desc["jobQueue"]
+ self.log.info(
+ "AWS Batch job (%s) Job Definition ARN: %r, Job Queue ARN: %r",
+ self.job_id,
+ job_definition_arn,
+ job_queue_arn,
+ )
+ except KeyError:
+ self.log.warning("AWS Batch job (%s) can't get Job Definition ARN and Job Queue ARN", self.job_id)
+ else:
+ BatchJobDefinitionLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
+ job_definition_arn=job_definition_arn,
+ )
+ BatchJobQueueLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
+ job_queue_arn=job_queue_arn,
+ )
if self.waiters:
self.waiters.wait_for_job(self.job_id)
else:
self.hook.wait_for_job(self.job_id)
+ awslogs = self.hook.get_job_awslogs_info(self.job_id)
+ if awslogs:
+ self.log.info("AWS Batch job (%s) CloudWatch Events details found: %s", self.job_id, awslogs)
+ CloudWatchEventsLink.persist(
+ context=context,
+ operator=self,
+ region_name=self.hook.conn_region_name,
+ aws_partition=self.hook.conn_partition,
+ **awslogs,
+ )
+
self.hook.check_job_success(self.job_id)
self.log.info("AWS Batch job (%s) succeeded", self.job_id)
-class AwsBatchOperator(BatchOperator):
+class BatchCreateComputeEnvironmentOperator(BaseOperator):
"""
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.batch.BatchOperator`.
+ Create an AWS Batch compute environment
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:BatchCreateComputeEnvironmentOperator`
+
+ :param compute_environment_name: the name of the AWS Batch compute environment (templated)
+
+ :param environment_type: the type of the compute-environment
+
+ :param state: the state of the compute-environment
+
+ :param compute_resources: details about the resources managed by the compute-environment (templated).
+ See more details here
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/batch.html#Batch.Client.create_compute_environment
+
+ :param unmanaged_v_cpus: the maximum number of vCPUs for an unmanaged compute environment.
+ This parameter is only supported when the ``type`` parameter is set to ``UNMANAGED``.
+
+ :param service_role: the IAM role that allows Batch to make calls to other AWS services on your behalf
+ (templated)
+
+ :param tags: the tags that you apply to the compute-environment to help you categorize and organize your
+ resources
+
+ :param max_retries: exponential back-off retries, 4200 = 48 hours;
+ polling is only used when waiters is None
+
+ :param status_retries: number of HTTP retries to get job status, 10;
+ polling is only used when waiters is None
+
+ :param aws_conn_id: connection id of AWS credentials / region name. If None,
+ the default boto3 credential strategy will be used.
+
+ :param region_name: region name to use in AWS Hook.
+ Override the region_name in connection (if provided)
"""
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.operators.batch.BatchOperator`.",
- DeprecationWarning,
- stacklevel=2,
+ template_fields: Sequence[str] = (
+ "compute_environment_name",
+ "compute_resources",
+ "service_role",
+ )
+ template_fields_renderers = {"compute_resources": "json"}
+
+ def __init__(
+ self,
+ compute_environment_name: str,
+ environment_type: str,
+ state: str,
+ compute_resources: dict,
+ unmanaged_v_cpus: int | None = None,
+ service_role: str | None = None,
+ tags: dict | None = None,
+ max_retries: int | None = None,
+ status_retries: int | None = None,
+ aws_conn_id: str | None = None,
+ region_name: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.compute_environment_name = compute_environment_name
+ self.environment_type = environment_type
+ self.state = state
+ self.unmanaged_v_cpus = unmanaged_v_cpus
+ self.compute_resources = compute_resources
+ self.service_role = service_role
+ self.tags = tags or {}
+ self.max_retries = max_retries
+ self.status_retries = status_retries
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ @cached_property
+ def hook(self):
+ """Create and return a BatchClientHook"""
+ return BatchClientHook(
+ max_retries=self.max_retries,
+ status_retries=self.status_retries,
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
)
- super().__init__(*args, **kwargs)
+
+ def execute(self, context: Context):
+ """Create an AWS Batch compute environment"""
+ kwargs: dict[str, Any] = {
+ "computeEnvironmentName": self.compute_environment_name,
+ "type": self.environment_type,
+ "state": self.state,
+ "unmanagedvCpus": self.unmanaged_v_cpus,
+ "computeResources": self.compute_resources,
+ "serviceRole": self.service_role,
+ "tags": self.tags,
+ }
+ self.hook.client.create_compute_environment(**trim_none_values(kwargs))
+
+ self.log.info("AWS Batch compute environment created successfully")
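`BatchCreateComputeEnvironmentOperator.execute` above builds the full request dictionary and relies on `trim_none_values` to drop optional arguments that were never set, so the API call only receives keys that carry a value. The sketch below illustrates that idea under the assumption that the helper simply filters out `None` values. `drop_none_values` and `build_compute_environment_kwargs` are hypothetical stand-ins, not the provider's code.

```python
from typing import Any, Dict, Optional


def drop_none_values(params: Dict[str, Any]) -> Dict[str, Any]:
    """Assumed behaviour of the helper: keep only keys whose value is not None."""
    return {key: value for key, value in params.items() if value is not None}


def build_compute_environment_kwargs(
    name: str,
    environment_type: str,
    state: str,
    compute_resources: Dict[str, Any],
    unmanaged_v_cpus: Optional[int] = None,
    service_role: Optional[str] = None,
    tags: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Assemble the request arguments the way the operator's execute() does."""
    kwargs: Dict[str, Any] = {
        "computeEnvironmentName": name,
        "type": environment_type,
        "state": state,
        "unmanagedvCpus": unmanaged_v_cpus,
        "computeResources": compute_resources,
        "serviceRole": service_role,
        # The operator stores ``tags or {}``, so this key is never None.
        "tags": tags or {},
    }
    return drop_none_values(kwargs)
```

Filtering on `is not None` rather than on truthiness matters here: an empty `tags` dictionary or a value of `0` is still passed through, and only arguments that were genuinely omitted disappear.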
diff --git a/airflow/providers/amazon/aws/operators/cloud_formation.py b/airflow/providers/amazon/aws/operators/cloud_formation.py
index 2e7bb8ae121c6..423ec36e598e3 100644
--- a/airflow/providers/amazon/aws/operators/cloud_formation.py
+++ b/airflow/providers/amazon/aws/operators/cloud_formation.py
@@ -16,7 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains CloudFormation create/delete stack operators."""
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
@@ -39,20 +41,20 @@ class CloudFormationCreateStackOperator(BaseOperator):
:param aws_conn_id: aws connection to uses
"""
- template_fields: Sequence[str] = ('stack_name',)
+ template_fields: Sequence[str] = ("stack_name", "cloudformation_parameters")
template_ext: Sequence[str] = ()
- ui_color = '#6b9659'
+ ui_color = "#6b9659"
def __init__(
- self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = 'aws_default', **kwargs
+ self, *, stack_name: str, cloudformation_parameters: dict, aws_conn_id: str = "aws_default", **kwargs
):
super().__init__(**kwargs)
self.stack_name = stack_name
self.cloudformation_parameters = cloudformation_parameters
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
- self.log.info('CloudFormation parameters: %s', self.cloudformation_parameters)
+ def execute(self, context: Context):
+ self.log.info("CloudFormation parameters: %s", self.cloudformation_parameters)
cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
cloudformation_hook.create_stack(self.stack_name, self.cloudformation_parameters)
@@ -72,17 +74,17 @@ class CloudFormationDeleteStackOperator(BaseOperator):
:param aws_conn_id: aws connection to uses
"""
- template_fields: Sequence[str] = ('stack_name',)
+ template_fields: Sequence[str] = ("stack_name",)
template_ext: Sequence[str] = ()
- ui_color = '#1d472b'
- ui_fgcolor = '#FFF'
+ ui_color = "#1d472b"
+ ui_fgcolor = "#FFF"
def __init__(
self,
*,
stack_name: str,
- cloudformation_parameters: Optional[dict] = None,
- aws_conn_id: str = 'aws_default',
+ cloudformation_parameters: dict | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -90,8 +92,8 @@ def __init__(
self.stack_name = stack_name
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
- self.log.info('CloudFormation Parameters: %s', self.cloudformation_parameters)
+ def execute(self, context: Context):
+ self.log.info("CloudFormation Parameters: %s", self.cloudformation_parameters)
cloudformation_hook = CloudFormationHook(aws_conn_id=self.aws_conn_id)
cloudformation_hook.delete_stack(self.stack_name, self.cloudformation_parameters)
diff --git a/airflow/providers/amazon/aws/operators/datasync.py b/airflow/providers/amazon/aws/operators/datasync.py
index 5a0c86071139f..508d87e163658 100644
--- a/airflow/providers/amazon/aws/operators/datasync.py
+++ b/airflow/providers/amazon/aws/operators/datasync.py
@@ -14,13 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Create, get, update, execute and delete an AWS DataSync Task."""
+from __future__ import annotations
import logging
import random
-import warnings
-from typing import TYPE_CHECKING, List, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.models import BaseOperator
@@ -49,6 +48,7 @@ class DataSyncOperator(BaseOperator):
consecutive calls to check TaskExecution status.
:param max_iterations: Maximum number of
consecutive calls to check TaskExecution status.
+ :param wait_for_completion: If True, wait for the task execution to reach a final state
:param task_arn: AWS DataSync TaskArn to use. If None, then this operator will
attempt to either search for an existing Task or attempt to create a new Task.
:param source_location_uri: Source location URI to search for. All DataSync
@@ -122,16 +122,17 @@ def __init__(
aws_conn_id: str = "aws_default",
wait_interval_seconds: int = 30,
max_iterations: int = 60,
- task_arn: Optional[str] = None,
- source_location_uri: Optional[str] = None,
- destination_location_uri: Optional[str] = None,
+ wait_for_completion: bool = True,
+ task_arn: str | None = None,
+ source_location_uri: str | None = None,
+ destination_location_uri: str | None = None,
allow_random_task_choice: bool = False,
allow_random_location_choice: bool = False,
- create_task_kwargs: Optional[dict] = None,
- create_source_location_kwargs: Optional[dict] = None,
- create_destination_location_kwargs: Optional[dict] = None,
- update_task_kwargs: Optional[dict] = None,
- task_execution_kwargs: Optional[dict] = None,
+ create_task_kwargs: dict | None = None,
+ create_source_location_kwargs: dict | None = None,
+ create_destination_location_kwargs: dict | None = None,
+ update_task_kwargs: dict | None = None,
+ task_execution_kwargs: dict | None = None,
delete_task_after_execution: bool = False,
**kwargs,
):
@@ -141,6 +142,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.wait_interval_seconds = wait_interval_seconds
self.max_iterations = max_iterations
+ self.wait_for_completion = wait_for_completion
self.task_arn = task_arn
@@ -175,16 +177,16 @@ def __init__(
)
# Others
- self.hook: Optional[DataSyncHook] = None
+ self.hook: DataSyncHook | None = None
# Candidates - these are found in AWS as possible things
# for us to use
- self.candidate_source_location_arns: Optional[List[str]] = None
- self.candidate_destination_location_arns: Optional[List[str]] = None
- self.candidate_task_arns: Optional[List[str]] = None
+ self.candidate_source_location_arns: list[str] | None = None
+ self.candidate_destination_location_arns: list[str] | None = None
+ self.candidate_task_arns: list[str] | None = None
# Actuals
- self.source_location_arn: Optional[str] = None
- self.destination_location_arn: Optional[str] = None
- self.task_execution_arn: Optional[str] = None
+ self.source_location_arn: str | None = None
+ self.destination_location_arn: str | None = None
+ self.task_execution_arn: str | None = None
def get_hook(self) -> DataSyncHook:
"""Create and return DataSyncHook.
@@ -200,7 +202,7 @@ def get_hook(self) -> DataSyncHook:
)
return self.hook
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
# If task_arn was not specified then try to
# find 0, 1 or many candidate DataSync Tasks to run
if not self.task_arn:
@@ -258,7 +260,7 @@ def _get_tasks_and_locations(self) -> None:
)
self.log.info("Found candidate DataSync TaskArns %s", self.candidate_task_arns)
- def choose_task(self, task_arn_list: list) -> Optional[str]:
+ def choose_task(self, task_arn_list: list) -> str | None:
"""Select 1 DataSync TaskArn from a list"""
if not task_arn_list:
return None
@@ -272,7 +274,7 @@ def choose_task(self, task_arn_list: list) -> Optional[str]:
return random.choice(task_arn_list)
raise AirflowException(f"Unable to choose a Task from {task_arn_list}")
- def choose_location(self, location_arn_list: Optional[List[str]]) -> Optional[str]:
+ def choose_location(self, location_arn_list: list[str] | None) -> str | None:
"""Select 1 DataSync LocationArn from a list"""
if not location_arn_list:
return None
@@ -292,7 +294,7 @@ def _create_datasync_task(self) -> None:
self.source_location_arn = self.choose_location(self.candidate_source_location_arns)
if not self.source_location_arn and self.source_location_uri and self.create_source_location_kwargs:
- self.log.info('Attempting to create source Location')
+ self.log.info("Attempting to create source Location")
self.source_location_arn = hook.create_location(
self.source_location_uri, **self.create_source_location_kwargs
)
@@ -307,7 +309,7 @@ def _create_datasync_task(self) -> None:
and self.destination_location_uri
and self.create_destination_location_kwargs
):
- self.log.info('Attempting to create destination Location')
+ self.log.info("Attempting to create destination Location")
self.destination_location_arn = hook.create_location(
self.destination_location_uri, **self.create_destination_location_kwargs
)
@@ -346,12 +348,15 @@ def _execute_datasync_task(self) -> None:
self.task_execution_arn = hook.start_task_execution(self.task_arn, **self.task_execution_kwargs)
self.log.info("Started TaskExecutionArn %s", self.task_execution_arn)
+ if not self.wait_for_completion:
+ return
+
# Wait for task execution to complete
self.log.info("Waiting for TaskExecutionArn %s", self.task_execution_arn)
try:
result = hook.wait_for_task_execution(self.task_execution_arn, max_iterations=self.max_iterations)
except (AirflowTaskTimeout, AirflowException) as e:
- self.log.error('Cancelling TaskExecution after Exception: %s', e)
+ self.log.error("Cancelling TaskExecution after Exception: %s", e)
self._cancel_datasync_task_execution()
raise
self.log.info("Completed TaskExecutionArn %s", self.task_execution_arn)
@@ -361,11 +366,11 @@ def _execute_datasync_task(self) -> None:
# Log some meaningful statuses
level = logging.ERROR if not result else logging.INFO
- self.log.log(level, 'Status=%s', task_execution_description['Status'])
- if 'Result' in task_execution_description:
- for k, v in task_execution_description['Result'].items():
- if 'Status' in k or 'Error' in k:
- self.log.log(level, '%s=%s', k, v)
+ self.log.log(level, "Status=%s", task_execution_description["Status"])
+ if "Result" in task_execution_description:
+ for k, v in task_execution_description["Result"].items():
+ if "Status" in k or "Error" in k:
+ self.log.log(level, "%s=%s", k, v)
if not result:
raise AirflowException(f"Failed TaskExecutionArn {self.task_execution_arn}")
@@ -379,7 +384,7 @@ def _cancel_datasync_task_execution(self):
self.log.info("Cancelled TaskExecutionArn %s", self.task_execution_arn)
def on_kill(self):
- self.log.error('Cancelling TaskExecution after task was killed')
+ self.log.error("Cancelling TaskExecution after task was killed")
self._cancel_datasync_task_execution()
def _delete_datasync_task(self) -> None:
@@ -393,23 +398,7 @@ def _delete_datasync_task(self) -> None:
hook.delete_task(self.task_arn)
self.log.info("Task Deleted")
- def _get_location_arns(self, location_uri) -> List[str]:
+ def _get_location_arns(self, location_uri) -> list[str]:
location_arns = self.get_hook().get_location_arns(location_uri)
self.log.info("Found LocationArns %s for LocationUri %s", location_arns, location_uri)
return location_arns
-
-
-class AWSDataSyncOperator(DataSyncOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.datasync.DataSyncOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. Please use "
- "`airflow.providers.amazon.aws.operators.datasync.DataSyncHook`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/dms.py b/airflow/providers/amazon/aws/operators/dms.py
index aca515cfed3a9..6303afcbaf318 100644
--- a/airflow/providers/amazon/aws/operators/dms.py
+++ b/airflow/providers/amazon/aws/operators/dms.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dms import DmsHook
@@ -47,13 +49,13 @@ class DmsCreateTaskOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'replication_task_id',
- 'source_endpoint_arn',
- 'target_endpoint_arn',
- 'replication_instance_arn',
- 'table_mappings',
- 'migration_type',
- 'create_task_kwargs',
+ "replication_task_id",
+ "source_endpoint_arn",
+ "target_endpoint_arn",
+ "replication_instance_arn",
+ "table_mappings",
+ "migration_type",
+ "create_task_kwargs",
)
template_ext: Sequence[str] = ()
template_fields_renderers = {
@@ -69,9 +71,9 @@ def __init__(
target_endpoint_arn: str,
replication_instance_arn: str,
table_mappings: dict,
- migration_type: str = 'full-load',
- create_task_kwargs: Optional[dict] = None,
- aws_conn_id: str = 'aws_default',
+ migration_type: str = "full-load",
+ create_task_kwargs: dict | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -84,7 +86,7 @@ def __init__(
self.create_task_kwargs = create_task_kwargs or {}
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Creates AWS DMS replication task from Airflow
@@ -122,22 +124,22 @@ class DmsDeleteTaskOperator(BaseOperator):
maintained on each worker node).
"""
- template_fields: Sequence[str] = ('replication_task_arn',)
+ template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
- template_fields_renderers: Dict[str, str] = {}
+ template_fields_renderers: dict[str, str] = {}
def __init__(
self,
*,
- replication_task_arn: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
+ replication_task_arn: str | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Deletes AWS DMS replication task from Airflow
@@ -164,27 +166,26 @@ class DmsDescribeTasksOperator(BaseOperator):
maintained on each worker node).
"""
- template_fields: Sequence[str] = ('describe_tasks_kwargs',)
+ template_fields: Sequence[str] = ("describe_tasks_kwargs",)
template_ext: Sequence[str] = ()
- template_fields_renderers: Dict[str, str] = {'describe_tasks_kwargs': 'json'}
+ template_fields_renderers: dict[str, str] = {"describe_tasks_kwargs": "json"}
def __init__(
self,
*,
- describe_tasks_kwargs: Optional[dict] = None,
- aws_conn_id: str = 'aws_default',
+ describe_tasks_kwargs: dict | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.describe_tasks_kwargs = describe_tasks_kwargs or {}
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context) -> tuple[str | None, list]:
"""
Describes AWS DMS replication tasks from Airflow
:return: Marker and list of replication tasks
- :rtype: (Optional[str], list)
"""
dms_hook = DmsHook(aws_conn_id=self.aws_conn_id)
return dms_hook.describe_replication_tasks(**self.describe_tasks_kwargs)
@@ -210,20 +211,20 @@ class DmsStartTaskOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'replication_task_arn',
- 'start_replication_task_type',
- 'start_task_kwargs',
+ "replication_task_arn",
+ "start_replication_task_type",
+ "start_task_kwargs",
)
template_ext: Sequence[str] = ()
- template_fields_renderers = {'start_task_kwargs': 'json'}
+ template_fields_renderers = {"start_task_kwargs": "json"}
def __init__(
self,
*,
replication_task_arn: str,
- start_replication_task_type: str = 'start-replication',
- start_task_kwargs: Optional[dict] = None,
- aws_conn_id: str = 'aws_default',
+ start_replication_task_type: str = "start-replication",
+ start_task_kwargs: dict | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -232,7 +233,7 @@ def __init__(
self.start_task_kwargs = start_task_kwargs or {}
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Starts AWS DMS replication task from Airflow
@@ -264,22 +265,22 @@ class DmsStopTaskOperator(BaseOperator):
maintained on each worker node).
"""
- template_fields: Sequence[str] = ('replication_task_arn',)
+ template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
- template_fields_renderers: Dict[str, str] = {}
+ template_fields_renderers: dict[str, str] = {}
def __init__(
self,
*,
- replication_task_arn: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
+ replication_task_arn: str | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.replication_task_arn = replication_task_arn
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Stops AWS DMS replication task from Airflow
diff --git a/airflow/providers/amazon/aws/operators/dms_create_task.py b/airflow/providers/amazon/aws/operators/dms_create_task.py
deleted file mode 100644
index 0443772dc6aee..0000000000000
--- a/airflow/providers/amazon/aws/operators/dms_create_task.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.dms import DmsCreateTaskOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/dms_delete_task.py b/airflow/providers/amazon/aws/operators/dms_delete_task.py
deleted file mode 100644
index ddb929efce1ce..0000000000000
--- a/airflow/providers/amazon/aws/operators/dms_delete_task.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.dms import DmsDeleteTaskOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/dms_describe_tasks.py b/airflow/providers/amazon/aws/operators/dms_describe_tasks.py
deleted file mode 100644
index 022c028c89b96..0000000000000
--- a/airflow/providers/amazon/aws/operators/dms_describe_tasks.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.dms import DmsDescribeTasksOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/dms_start_task.py b/airflow/providers/amazon/aws/operators/dms_start_task.py
deleted file mode 100644
index ba78ea60bed01..0000000000000
--- a/airflow/providers/amazon/aws/operators/dms_start_task.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.dms import DmsStartTaskOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/dms_stop_task.py b/airflow/providers/amazon/aws/operators/dms_stop_task.py
deleted file mode 100644
index c4ad5f3a4a7cf..0000000000000
--- a/airflow/providers/amazon/aws/operators/dms_stop_task.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.dms`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.dms import DmsStopTaskOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.dms`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/ec2.py b/airflow/providers/amazon/aws/operators/ec2.py
index 133596929e423..60cb43a32dd09 100644
--- a/airflow/providers/amazon/aws/operators/ec2.py
+++ b/airflow/providers/amazon/aws/operators/ec2.py
@@ -15,9 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
@@ -50,7 +50,7 @@ def __init__(
*,
instance_id: str,
aws_conn_id: str = "aws_default",
- region_name: Optional[str] = None,
+ region_name: str | None = None,
check_interval: float = 15,
**kwargs,
):
@@ -60,7 +60,7 @@ def __init__(
self.region_name = region_name
self.check_interval = check_interval
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
self.log.info("Starting EC2 instance %s", self.instance_id)
instance = ec2_hook.get_instance(instance_id=self.instance_id)
@@ -96,7 +96,7 @@ def __init__(
*,
instance_id: str,
aws_conn_id: str = "aws_default",
- region_name: Optional[str] = None,
+ region_name: str | None = None,
check_interval: float = 15,
**kwargs,
):
@@ -106,7 +106,7 @@ def __init__(
self.region_name = region_name
self.check_interval = check_interval
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
self.log.info("Stopping EC2 instance %s", self.instance_id)
instance = ec2_hook.get_instance(instance_id=self.instance_id)
diff --git a/airflow/providers/amazon/aws/operators/ec2_start_instance.py b/airflow/providers/amazon/aws/operators/ec2_start_instance.py
deleted file mode 100644
index c2c25e5708b0a..0000000000000
--- a/airflow/providers/amazon/aws/operators/ec2_start_instance.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ec2`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.ec2 import EC2StartInstanceOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py b/airflow/providers/amazon/aws/operators/ec2_stop_instance.py
deleted file mode 100644
index ddafa21c5bd5e..0000000000000
--- a/airflow/providers/amazon/aws/operators/ec2_stop_instance.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.ec2`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.ec2 import EC2StopInstanceOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.ec2`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index d1112edf445ee..0652589ea55bf 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -15,163 +15,261 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import re
import sys
-import time
import warnings
-from collections import deque
-from datetime import datetime, timedelta
-from logging import Logger
-from threading import Event, Thread
-from typing import Dict, Generator, Optional, Sequence
+from datetime import timedelta
+from typing import TYPE_CHECKING, Sequence
-from botocore.exceptions import ClientError, ConnectionClosedError
-from botocore.waiter import Waiter
+import boto3
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, XCom
from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
+
+# TODO: Remove the following import when EcsProtocol and EcsTaskLogFetcher deprecations are removed.
+from airflow.providers.amazon.aws.hooks import ecs
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
-from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
-from airflow.typing_compat import Protocol, runtime_checkable
+from airflow.providers.amazon.aws.hooks.ecs import (
+ EcsClusterStates,
+ EcsHook,
+ EcsTaskDefinitionStates,
+ should_retry_eni,
+)
+from airflow.providers.amazon.aws.sensors.ecs import EcsClusterStateSensor, EcsTaskDefinitionStateSensor
from airflow.utils.session import provide_session
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
-def should_retry(exception: Exception):
- """Check if exception is related to ECS resource quota (CPU, MEM)."""
- if isinstance(exception, EcsOperatorError):
- return any(
- quota_reason in failure['reason']
- for quota_reason in ['RESOURCE:MEMORY', 'RESOURCE:CPU']
- for failure in exception.failures
- )
- return False
+DEFAULT_CONN_ID = "aws_default"
-def should_retry_eni(exception: Exception):
- """Check if exception is related to ENI (Elastic Network Interfaces)."""
- if isinstance(exception, EcsTaskFailToStart):
- return any(
- eni_reason in exception.message
- for eni_reason in ['network interface provisioning', 'ResourceInitializationError']
- )
- return False
+class EcsBaseOperator(BaseOperator):
+ """This is the base operator for all Elastic Container Service operators."""
+
+ def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs):
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+ super().__init__(**kwargs)
+
+ @cached_property
+ def hook(self) -> EcsHook:
+ """Create and return an EcsHook."""
+ return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+ @cached_property
+ def client(self) -> boto3.client:
+ """Create and return the EcsHook's client."""
+ return self.hook.conn
-@runtime_checkable
-class EcsProtocol(Protocol):
+ def execute(self, context: Context):
+ """Must overwrite in child classes."""
+ raise NotImplementedError("Please implement execute() in subclass")
+
+
+class EcsCreateClusterOperator(EcsBaseOperator):
"""
- A structured Protocol for ``boto3.client('ecs')``. This is used for type hints on
- :py:meth:`.EcsOperator.client`.
+ Creates an AWS ECS cluster.
.. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EcsCreateClusterOperator`
- - https://mypy.readthedocs.io/en/latest/protocols.html
- - https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html
+ :param cluster_name: The name of your cluster. If you don't specify a name for your
+ cluster, you create a cluster that's named default.
+ :param create_cluster_kwargs: Extra arguments for Cluster Creation.
+ :param wait_for_completion: If True, waits for creation of the cluster to complete. (default: True)
"""
- def run_task(self, **kwargs) -> Dict:
- """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task""" # noqa: E501
- ...
-
- def get_waiter(self, x: str) -> Waiter:
- """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.get_waiter""" # noqa: E501
- ...
+ template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion")
- def describe_tasks(self, cluster: str, tasks) -> Dict:
- """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_tasks""" # noqa: E501
- ...
+ def __init__(
+ self,
+ *,
+ cluster_name: str,
+ create_cluster_kwargs: dict | None = None,
+ wait_for_completion: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.cluster_name = cluster_name
+ self.create_cluster_kwargs = create_cluster_kwargs or {}
+ self.wait_for_completion = wait_for_completion
- def stop_task(self, cluster, task, reason: str) -> Dict:
- """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.stop_task""" # noqa: E501
- ...
+ def execute(self, context: Context):
+ self.log.info(
+ "Creating cluster %s using the following values: %s",
+ self.cluster_name,
+ self.create_cluster_kwargs,
+ )
+ result = self.client.create_cluster(clusterName=self.cluster_name, **self.create_cluster_kwargs)
- def describe_task_definition(self, taskDefinition: str) -> Dict:
- """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.describe_task_definition""" # noqa: E501
- ...
+ if self.wait_for_completion:
+ while not EcsClusterStateSensor(
+ task_id="await_cluster",
+ cluster_name=self.cluster_name,
+ ).poke(context):
+ # The sensor has a built-in delay and will try again until
+ # the cluster is ready or has reached a failed state.
+ pass
- def list_tasks(self, cluster: str, launchType: str, desiredStatus: str, family: str) -> Dict:
- """https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.list_tasks""" # noqa: E501
- ...
+ return result["cluster"]
-class EcsTaskLogFetcher(Thread):
+class EcsDeleteClusterOperator(EcsBaseOperator):
"""
- Fetches Cloudwatch log events with specific interval as a thread
- and sends the log events to the info channel of the provided logger.
+ Deletes an AWS ECS cluster.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EcsDeleteClusterOperator`
+
+ :param cluster_name: The short name or full Amazon Resource Name (ARN) of the cluster to delete.
+ :param wait_for_completion: If True, waits for deletion of the cluster to complete. (default: True)
"""
+ template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
+
def __init__(
self,
*,
- aws_conn_id: Optional[str] = 'aws_default',
- region_name: Optional[str] = None,
- log_group: str,
- log_stream_name: str,
- fetch_interval: timedelta,
- logger: Logger,
- ):
- super().__init__()
- self._event = Event()
+ cluster_name: str,
+ wait_for_completion: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.cluster_name = cluster_name
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context):
+ self.log.info("Deleting cluster %s.", self.cluster_name)
+ result = self.client.delete_cluster(cluster=self.cluster_name)
+
+ if self.wait_for_completion:
+ while not EcsClusterStateSensor(
+ task_id="await_cluster_delete",
+ cluster_name=self.cluster_name,
+ target_state=EcsClusterStates.INACTIVE,
+ failure_states={EcsClusterStates.FAILED},
+ ).poke(context):
+ # The sensor has a built-in delay and will try again until
+ # the cluster is deleted or reaches a failed state.
+ pass
+
+ return result["cluster"]
+
+
+class EcsDeregisterTaskDefinitionOperator(EcsBaseOperator):
+ """
+ Deregister a task definition on AWS ECS.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EcsDeregisterTaskDefinitionOperator`
+
+ :param task_definition: The family and revision (family:revision) or full Amazon Resource Name (ARN)
+ of the task definition to deregister. If you use a family name, you must specify a revision.
+ :param wait_for_completion: If True, waits for the task definition to be deregistered. (default: True)
+ """
+
+ template_fields: Sequence[str] = ("task_definition", "wait_for_completion")
+
+ def __init__(self, *, task_definition: str, wait_for_completion: bool = True, **kwargs):
+ super().__init__(**kwargs)
+ self.task_definition = task_definition
+ self.wait_for_completion = wait_for_completion
- self.fetch_interval = fetch_interval
+ def execute(self, context: Context):
+ self.log.info("Deregistering task definition %s.", self.task_definition)
+ result = self.client.deregister_task_definition(taskDefinition=self.task_definition)
- self.logger = logger
- self.log_group = log_group
- self.log_stream_name = log_stream_name
+ if self.wait_for_completion:
+ while not EcsTaskDefinitionStateSensor(
+ task_id="await_deregister_task_definition",
+ task_definition=self.task_definition,
+ target_state=EcsTaskDefinitionStates.INACTIVE,
+ ).poke(context):
+ # The sensor has a built-in delay and will try again until the
+ # task definition is deregistered or reaches a failed state.
+ pass
- self.hook = AwsLogsHook(aws_conn_id=aws_conn_id, region_name=region_name)
+ return result["taskDefinition"]["taskDefinitionArn"]
- def run(self) -> None:
- logs_to_skip = 0
- while not self.is_stopped():
- time.sleep(self.fetch_interval.total_seconds())
- log_events = self._get_log_events(logs_to_skip)
- for log_event in log_events:
- self.logger.info(self._event_to_str(log_event))
- logs_to_skip += 1
- def _get_log_events(self, skip: int = 0) -> Generator:
- try:
- yield from self.hook.get_log_events(self.log_group, self.log_stream_name, skip=skip)
- except ClientError as error:
- if error.response['Error']['Code'] != 'ResourceNotFoundException':
- self.logger.warning('Error on retrieving Cloudwatch log events', error)
+class EcsRegisterTaskDefinitionOperator(EcsBaseOperator):
+ """
+ Register a task definition on AWS ECS.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EcsRegisterTaskDefinitionOperator`
- yield from ()
- except ConnectionClosedError as error:
- self.logger.warning('ConnectionClosedError on retrieving Cloudwatch log events', error)
- yield from ()
+ :param family: The family name of a task definition to create.
+ :param container_definitions: A list of container definitions in JSON format that describe
+ the different containers that make up your task.
+ :param register_task_kwargs: Extra arguments for Register Task Definition.
+ :param wait_for_completion: If True, waits for the task definition to be registered. (default: True)
+ """
- def _event_to_str(self, event: dict) -> str:
- event_dt = datetime.utcfromtimestamp(event['timestamp'] / 1000.0)
- formatted_event_dt = event_dt.strftime('%Y-%m-%d %H:%M:%S,%f')[:-3]
- message = event['message']
- return f'[{formatted_event_dt}] {message}'
+ template_fields: Sequence[str] = (
+ "family",
+ "container_definitions",
+ "register_task_kwargs",
+ "wait_for_completion",
+ )
- def get_last_log_messages(self, number_messages) -> list:
- return [log['message'] for log in deque(self._get_log_events(), maxlen=number_messages)]
+ def __init__(
+ self,
+ *,
+ family: str,
+ container_definitions: list[dict],
+ register_task_kwargs: dict | None = None,
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.family = family
+ self.container_definitions = container_definitions
+ self.register_task_kwargs = register_task_kwargs or {}
+ self.wait_for_completion = wait_for_completion
- def get_last_log_message(self) -> Optional[str]:
- try:
- return self.get_last_log_messages(1)[0]
- except IndexError:
- return None
+ def execute(self, context: Context):
+ self.log.info(
+ "Registering task definition %s using the following values: %s",
+ self.family,
+ self.register_task_kwargs,
+ )
+ self.log.info("Using container definition %s", self.container_definitions)
+ response = self.client.register_task_definition(
+ family=self.family,
+ containerDefinitions=self.container_definitions,
+ **self.register_task_kwargs,
+ )
+ task_arn = response["taskDefinition"]["taskDefinitionArn"]
- def is_stopped(self) -> bool:
- return self._event.is_set()
+ if self.wait_for_completion:
+ while not EcsTaskDefinitionStateSensor(
+ task_id="await_register_task_definition", task_definition=task_arn
+ ).poke(context):
+ # The sensor has a built-in delay and will try again until
+ # the task definition is registered or reaches a failed state.
+ pass
- def stop(self):
- self._event.set()
+ context["ti"].xcom_push(key="task_definition_arn", value=task_arn)
+ return task_arn
-class EcsOperator(BaseOperator):
+class EcsRunTaskOperator(EcsBaseOperator):
"""
Execute a task on AWS ECS (Elastic Container Service)
.. seealso::
For more information on how to use this operator, take a look at the guide:
- :ref:`howto/operator:EcsOperator`
+ :ref:`howto/operator:EcsRunTaskOperator`
:param task_definition: the task definition name on Elastic Container Service
:param cluster: the cluster name on Elastic Container Service
@@ -179,7 +277,7 @@ class EcsOperator(BaseOperator):
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task
:param aws_conn_id: connection id of AWS credentials / region name. If None,
credential boto3 strategy will be used
- (http://boto3.readthedocs.io/en/latest/guide/configuration.html).
+ (https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html).
:param region_name: region name to use in AWS Hook.
Override the region_name in connection (if provided)
:param launch_type: the launch type on which to run your task ('EC2', 'EXTERNAL', or 'FARGATE')
@@ -216,15 +314,35 @@ class EcsOperator(BaseOperator):
:param number_logs_exception: Number of lines from the last Cloudwatch logs to return in the
AirflowException if an ECS task is stopped (to receive Airflow alerts with the logs of what
failed in the code running in ECS).
+ :param wait_for_completion: If True, waits for the ECS task to end. (default: True)
"""
- ui_color = '#f0ede4'
- template_fields: Sequence[str] = ('overrides',)
+ ui_color = "#f0ede4"
+ template_fields: Sequence[str] = (
+ "task_definition",
+ "cluster",
+ "overrides",
+ "launch_type",
+ "capacity_provider_strategy",
+ "group",
+ "placement_constraints",
+ "placement_strategy",
+ "platform_version",
+ "network_configuration",
+ "tags",
+ "awslogs_group",
+ "awslogs_region",
+ "awslogs_stream_prefix",
+ "awslogs_fetch_interval",
+ "propagate_tags",
+ "reattach",
+ "number_logs_exception",
+ "wait_for_completion",
+ )
template_fields_renderers = {
"overrides": "json",
"network_configuration": "json",
"tags": "json",
- "quota_retry": "json",
}
REATTACH_XCOM_KEY = "ecs_task_arn"
REATTACH_XCOM_TASK_ID_TEMPLATE = "{task_id}_task_arn"
@@ -235,30 +353,27 @@ def __init__(
task_definition: str,
cluster: str,
overrides: dict,
- aws_conn_id: Optional[str] = None,
- region_name: Optional[str] = None,
- launch_type: str = 'EC2',
- capacity_provider_strategy: Optional[list] = None,
- group: Optional[str] = None,
- placement_constraints: Optional[list] = None,
- placement_strategy: Optional[list] = None,
- platform_version: Optional[str] = None,
- network_configuration: Optional[dict] = None,
- tags: Optional[dict] = None,
- awslogs_group: Optional[str] = None,
- awslogs_region: Optional[str] = None,
- awslogs_stream_prefix: Optional[str] = None,
+ launch_type: str = "EC2",
+ capacity_provider_strategy: list | None = None,
+ group: str | None = None,
+ placement_constraints: list | None = None,
+ placement_strategy: list | None = None,
+ platform_version: str | None = None,
+ network_configuration: dict | None = None,
+ tags: dict | None = None,
+ awslogs_group: str | None = None,
+ awslogs_region: str | None = None,
+ awslogs_stream_prefix: str | None = None,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
- propagate_tags: Optional[str] = None,
- quota_retry: Optional[dict] = None,
+ propagate_tags: str | None = None,
+ quota_retry: dict | None = None,
reattach: bool = False,
number_logs_exception: int = 10,
+ wait_for_completion: bool = True,
**kwargs,
):
super().__init__(**kwargs)
- self.aws_conn_id = aws_conn_id
- self.region_name = region_name
self.task_definition = task_definition
self.cluster = cluster
self.overrides = overrides
@@ -280,29 +395,26 @@ def __init__(
self.number_logs_exception = number_logs_exception
if self.awslogs_region is None:
- self.awslogs_region = region_name
+ self.awslogs_region = self.region
- self.hook: Optional[AwsBaseHook] = None
- self.client: Optional[EcsProtocol] = None
- self.arn: Optional[str] = None
+ self.arn: str | None = None
self.retry_args = quota_retry
- self.task_log_fetcher: Optional[EcsTaskLogFetcher] = None
+ self.task_log_fetcher: EcsTaskLogFetcher | None = None
+ self.wait_for_completion = wait_for_completion
@provide_session
def execute(self, context, session=None):
self.log.info(
- 'Running ECS Task - Task definition: %s - on cluster %s', self.task_definition, self.cluster
+ "Running ECS Task - Task definition: %s - on cluster %s", self.task_definition, self.cluster
)
- self.log.info('EcsOperator overrides: %s', self.overrides)
-
- self.client = self.get_hook().get_conn()
+ self.log.info("EcsOperator overrides: %s", self.overrides)
if self.reattach:
self._try_reattach_task(context)
self._start_wait_check_task(context)
- self.log.info('ECS Task has been successfully executed')
+ self.log.info("ECS Task has been successfully executed")
if self.reattach:
# Clear the XCom value storing the ECS task ARN if the task has completed
@@ -321,18 +433,20 @@ def _start_wait_check_task(self, context):
self._start_task(context)
if self._aws_logs_enabled():
- self.log.info('Starting ECS Task Log Fetcher')
+ self.log.info("Starting ECS Task Log Fetcher")
self.task_log_fetcher = self._get_task_log_fetcher()
self.task_log_fetcher.start()
try:
- self._wait_for_task_ended()
+ if self.wait_for_completion:
+ self._wait_for_task_ended()
finally:
self.task_log_fetcher.stop()
self.task_log_fetcher.join()
else:
- self._wait_for_task_ended()
+ if self.wait_for_completion:
+ self._wait_for_task_ended()
self._check_success_task()
@@ -341,68 +455,54 @@ def _xcom_del(self, session, task_id):
def _start_task(self, context):
run_opts = {
- 'cluster': self.cluster,
- 'taskDefinition': self.task_definition,
- 'overrides': self.overrides,
- 'startedBy': self.owner,
+ "cluster": self.cluster,
+ "taskDefinition": self.task_definition,
+ "overrides": self.overrides,
+ "startedBy": self.owner,
}
if self.capacity_provider_strategy:
- run_opts['capacityProviderStrategy'] = self.capacity_provider_strategy
+ run_opts["capacityProviderStrategy"] = self.capacity_provider_strategy
elif self.launch_type:
- run_opts['launchType'] = self.launch_type
+ run_opts["launchType"] = self.launch_type
if self.platform_version is not None:
- run_opts['platformVersion'] = self.platform_version
+ run_opts["platformVersion"] = self.platform_version
if self.group is not None:
- run_opts['group'] = self.group
+ run_opts["group"] = self.group
if self.placement_constraints is not None:
- run_opts['placementConstraints'] = self.placement_constraints
+ run_opts["placementConstraints"] = self.placement_constraints
if self.placement_strategy is not None:
- run_opts['placementStrategy'] = self.placement_strategy
+ run_opts["placementStrategy"] = self.placement_strategy
if self.network_configuration is not None:
- run_opts['networkConfiguration'] = self.network_configuration
+ run_opts["networkConfiguration"] = self.network_configuration
if self.tags is not None:
- run_opts['tags'] = [{'key': k, 'value': v} for (k, v) in self.tags.items()]
+ run_opts["tags"] = [{"key": k, "value": v} for (k, v) in self.tags.items()]
if self.propagate_tags is not None:
- run_opts['propagateTags'] = self.propagate_tags
+ run_opts["propagateTags"] = self.propagate_tags
response = self.client.run_task(**run_opts)
- failures = response['failures']
+ failures = response["failures"]
if len(failures) > 0:
raise EcsOperatorError(failures, response)
- self.log.info('ECS Task started: %s', response)
+ self.log.info("ECS Task started: %s", response)
- self.arn = response['tasks'][0]['taskArn']
+ self.arn = response["tasks"][0]["taskArn"]
self.ecs_task_id = self.arn.split("/")[-1]
self.log.info("ECS task ID is: %s", self.ecs_task_id)
if self.reattach:
# Save the task ARN in XCom to be able to reattach it if needed
- self._xcom_set(
- context,
- key=self.REATTACH_XCOM_KEY,
- value=self.arn,
- task_id=self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id),
- )
-
- def _xcom_set(self, context, key, value, task_id):
- XCom.set(
- key=key,
- value=value,
- task_id=task_id,
- dag_id=self.dag_id,
- run_id=context["run_id"],
- )
+ self.xcom_push(context, key=self.REATTACH_XCOM_KEY, value=self.arn)
def _try_reattach_task(self, context):
task_def_resp = self.client.describe_task_definition(taskDefinition=self.task_definition)
- ecs_task_family = task_def_resp['taskDefinition']['family']
+ ecs_task_family = task_def_resp["taskDefinition"]["family"]
list_tasks_resp = self.client.list_tasks(
- cluster=self.cluster, desiredStatus='RUNNING', family=ecs_task_family
+ cluster=self.cluster, desiredStatus="RUNNING", family=ecs_task_family
)
- running_tasks = list_tasks_resp['taskArns']
+ running_tasks = list_tasks_resp["taskArns"]
# Check if the ECS task previously launched is already running
previous_task_arn = self.xcom_pull(
@@ -421,7 +521,7 @@ def _wait_for_task_ended(self) -> None:
if not self.client or not self.arn:
return
- waiter = self.client.get_waiter('tasks_stopped')
+ waiter = self.client.get_waiter("tasks_stopped")
waiter.config.max_attempts = sys.maxsize # timeout is managed by airflow
waiter.wait(cluster=self.cluster, tasks=[self.arn])
@@ -430,12 +530,13 @@ def _wait_for_task_ended(self) -> None:
def _aws_logs_enabled(self):
return self.awslogs_group and self.awslogs_stream_prefix
- def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
+ # TODO: When the deprecation wrapper below is removed, please fix the following return type hint.
+ def _get_task_log_fetcher(self) -> ecs.EcsTaskLogFetcher:
if not self.awslogs_group:
raise ValueError("must specify awslogs_group to fetch task logs")
log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}"
- return EcsTaskLogFetcher(
+ return ecs.EcsTaskLogFetcher(
aws_conn_id=self.aws_conn_id,
region_name=self.awslogs_region,
log_group=self.awslogs_group,
@@ -449,14 +550,14 @@ def _check_success_task(self) -> None:
return
response = self.client.describe_tasks(cluster=self.cluster, tasks=[self.arn])
- self.log.info('ECS Task stopped, check status: %s', response)
+ self.log.info("ECS Task stopped, check status: %s", response)
- if len(response.get('failures', [])) > 0:
+ if len(response.get("failures", [])) > 0:
raise AirflowException(response)
- for task in response['tasks']:
+ for task in response["tasks"]:
- if task.get('stopCode', '') == 'TaskFailedToStart':
+ if task.get("stopCode", "") == "TaskFailedToStart":
# Reset task arn here otherwise the retry run will not start
# a new task but keep polling the old dead one
# I'm not resetting it for other exceptions here because
@@ -468,14 +569,14 @@ def _check_success_task(self) -> None:
# successfully finished, but there is no other indication of failure
# in the response.
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/stopped-task-errors.html
- if re.match(r'Host EC2 \(instance .+?\) (stopped|terminated)\.', task.get('stoppedReason', '')):
+ if re.match(r"Host EC2 \(instance .+?\) (stopped|terminated)\.", task.get("stoppedReason", "")):
raise AirflowException(
f"The task was stopped because the host instance terminated:"
f" {task.get('stoppedReason', '')}"
)
- containers = task['containers']
+ containers = task["containers"]
for container in containers:
- if container.get('lastStatus') == 'STOPPED' and container.get('exitCode', 1) != 0:
+ if container.get("lastStatus") == "STOPPED" and container.get("exitCode", 1) != 0:
if self.task_log_fetcher:
last_logs = "\n".join(
self.task_log_fetcher.get_last_log_messages(self.number_logs_exception)
@@ -485,23 +586,15 @@ def _check_success_task(self) -> None:
f"logs from Cloudwatch:\n{last_logs}"
)
else:
- raise AirflowException(f'This task is not in success state {task}')
- elif container.get('lastStatus') == 'PENDING':
- raise AirflowException(f'This task is still pending {task}')
- elif 'error' in container.get('reason', '').lower():
+ raise AirflowException(f"This task is not in success state {task}")
+ elif container.get("lastStatus") == "PENDING":
+ raise AirflowException(f"This task is still pending {task}")
+ elif "error" in container.get("reason", "").lower():
raise AirflowException(
f"This containers encounter an error during launching: "
f"{container.get('reason', '').lower()}"
)
- def get_hook(self) -> AwsBaseHook:
- """Create and return an AwsHook."""
- if self.hook:
- return self.hook
-
- self.hook = AwsBaseHook(aws_conn_id=self.aws_conn_id, client_type='ecs', region_name=self.region_name)
- return self.hook
-
def on_kill(self) -> None:
if not self.client or not self.arn:
return
@@ -510,53 +603,56 @@ def on_kill(self) -> None:
self.task_log_fetcher.stop()
response = self.client.stop_task(
- cluster=self.cluster, task=self.arn, reason='Task killed by the user'
+ cluster=self.cluster, task=self.arn, reason="Task killed by the user"
)
self.log.info(response)
-class ECSOperator(EcsOperator):
+class EcsOperator(EcsRunTaskOperator):
"""
This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsOperator`.
+ Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator`.
"""
def __init__(self, *args, **kwargs):
warnings.warn(
"This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.ecs.EcsOperator`.",
+ "Please use `airflow.providers.amazon.aws.operators.ecs.EcsRunTaskOperator`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
-class ECSTaskLogFetcher(EcsTaskLogFetcher):
+class EcsTaskLogFetcher(ecs.EcsTaskLogFetcher):
"""
This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.
+ Please use :class:`airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher`.
"""
+ # TODO: Note to deprecator: be sure to fix the use of `ecs.EcsTaskLogFetcher`
+ # in the Operators above when you remove this wrapper class.
def __init__(self, *args, **kwargs):
warnings.warn(
"This class is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.ecs.EcsTaskLogFetcher`.",
+ "Please use `airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher`.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
-class ECSProtocol(EcsProtocol):
+class EcsProtocol(ecs.EcsProtocol):
"""
This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.
+ Please use :class:`airflow.providers.amazon.aws.hooks.ecs.EcsProtocol`.
"""
+ # TODO: Note to deprecator: be sure to fix the use of `ecs.EcsProtocol`
+ # in the Operators above when you remove this wrapper class.
def __init__(self):
warnings.warn(
- "This class is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.ecs.EcsProtocol`.",
+ "This class is deprecated. Please use `airflow.providers.amazon.aws.hooks.ecs.EcsProtocol`.",
DeprecationWarning,
stacklevel=2,
)
diff --git a/airflow/providers/amazon/aws/operators/eks.py b/airflow/providers/amazon/aws/operators/eks.py
index 8e97015d86182..70d74f4cda849 100644
--- a/airflow/providers/amazon/aws/operators/eks.py
+++ b/airflow/providers/amazon/aws/operators/eks.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains Amazon EKS operators."""
+from __future__ import annotations
+
import warnings
from ast import literal_eval
from time import sleep
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Any, List, Sequence, cast
from airflow import AirflowException
from airflow.models import BaseOperator
@@ -32,21 +33,20 @@
CHECK_INTERVAL_SECONDS = 15
TIMEOUT_SECONDS = 25 * 60
-DEFAULT_COMPUTE_TYPE = 'nodegroup'
-DEFAULT_CONN_ID = 'aws_default'
-DEFAULT_FARGATE_PROFILE_NAME = 'profile'
-DEFAULT_NAMESPACE_NAME = 'default'
-DEFAULT_NODEGROUP_NAME = 'nodegroup'
-DEFAULT_POD_NAME = 'pod'
+DEFAULT_COMPUTE_TYPE = "nodegroup"
+DEFAULT_CONN_ID = "aws_default"
+DEFAULT_FARGATE_PROFILE_NAME = "profile"
+DEFAULT_NAMESPACE_NAME = "default"
+DEFAULT_NODEGROUP_NAME = "nodegroup"
ABORT_MSG = "{compute} are still active after the allocated time limit. Aborting."
CAN_NOT_DELETE_MSG = "A cluster can not be deleted with attached {compute}. Deleting {count} {compute}."
MISSING_ARN_MSG = "Creating an {compute} requires {requirement} to be passed in."
SUCCESS_MSG = "No {compute} remain, deleting cluster."
-SUPPORTED_COMPUTE_VALUES = frozenset({'nodegroup', 'fargate'})
-NODEGROUP_FULL_NAME = 'Amazon EKS managed node groups'
-FARGATE_FULL_NAME = 'AWS Fargate profiles'
+SUPPORTED_COMPUTE_VALUES = frozenset({"nodegroup", "fargate"})
+NODEGROUP_FULL_NAME = "Amazon EKS managed node groups"
+FARGATE_FULL_NAME = "AWS Fargate profiles"
class EksCreateClusterOperator(BaseOperator):
@@ -125,18 +125,18 @@ def __init__(
self,
cluster_name: str,
cluster_role_arn: str,
- resources_vpc_config: Dict[str, Any],
- compute: Optional[str] = DEFAULT_COMPUTE_TYPE,
- create_cluster_kwargs: Optional[Dict] = None,
+ resources_vpc_config: dict[str, Any],
+ compute: str | None = DEFAULT_COMPUTE_TYPE,
+ create_cluster_kwargs: dict | None = None,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
- nodegroup_role_arn: Optional[str] = None,
- create_nodegroup_kwargs: Optional[Dict] = None,
+ nodegroup_role_arn: str | None = None,
+ create_nodegroup_kwargs: dict | None = None,
fargate_profile_name: str = DEFAULT_FARGATE_PROFILE_NAME,
- fargate_pod_execution_role_arn: Optional[str] = None,
- fargate_selectors: Optional[List] = None,
- create_fargate_profile_kwargs: Optional[Dict] = None,
+ fargate_pod_execution_role_arn: str | None = None,
+ fargate_selectors: list | None = None,
+ create_fargate_profile_kwargs: dict | None = None,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
) -> None:
self.compute = compute
@@ -155,18 +155,18 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
if self.compute:
if self.compute not in SUPPORTED_COMPUTE_VALUES:
raise ValueError("Provided compute type is not supported.")
- elif (self.compute == 'nodegroup') and not self.nodegroup_role_arn:
+ elif (self.compute == "nodegroup") and not self.nodegroup_role_arn:
raise ValueError(
- MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement='nodegroup_role_arn')
+ MISSING_ARN_MSG.format(compute=NODEGROUP_FULL_NAME, requirement="nodegroup_role_arn")
)
- elif (self.compute == 'fargate') and not self.fargate_pod_execution_role_arn:
+ elif (self.compute == "fargate") and not self.fargate_pod_execution_role_arn:
raise ValueError(
MISSING_ARN_MSG.format(
- compute=FARGATE_FULL_NAME, requirement='fargate_pod_execution_role_arn'
+ compute=FARGATE_FULL_NAME, requirement="fargate_pod_execution_role_arn"
)
)
@@ -205,15 +205,15 @@ def execute(self, context: 'Context'):
eks_hook.delete_cluster(name=self.cluster_name)
raise RuntimeError(message)
- if self.compute == 'nodegroup':
+ if self.compute == "nodegroup":
eks_hook.create_nodegroup(
clusterName=self.cluster_name,
nodegroupName=self.nodegroup_name,
- subnets=cast(List[str], self.resources_vpc_config.get('subnetIds')),
+ subnets=cast(List[str], self.resources_vpc_config.get("subnetIds")),
nodeRole=self.nodegroup_role_arn,
**self.create_nodegroup_kwargs,
)
- elif self.compute == 'fargate':
+ elif self.compute == "fargate":
eks_hook.create_fargate_profile(
clusterName=self.cluster_name,
fargateProfileName=self.fargate_profile_name,
@@ -261,12 +261,12 @@ class EksCreateNodegroupOperator(BaseOperator):
def __init__(
self,
cluster_name: str,
- nodegroup_subnets: Union[List[str], str],
+ nodegroup_subnets: list[str] | str,
nodegroup_role_arn: str,
nodegroup_name: str = DEFAULT_NODEGROUP_NAME,
- create_nodegroup_kwargs: Optional[Dict] = None,
+ create_nodegroup_kwargs: dict | None = None,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
@@ -278,9 +278,9 @@ def __init__(
self.nodegroup_subnets = nodegroup_subnets
super().__init__(**kwargs)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
if isinstance(self.nodegroup_subnets, str):
- nodegroup_subnets_list: List[str] = []
+ nodegroup_subnets_list: list[str] = []
if self.nodegroup_subnets != "":
try:
nodegroup_subnets_list = cast(List, literal_eval(self.nodegroup_subnets))
@@ -344,11 +344,11 @@ def __init__(
self,
cluster_name: str,
pod_execution_role_arn: str,
- selectors: List,
- fargate_profile_name: Optional[str] = DEFAULT_FARGATE_PROFILE_NAME,
- create_fargate_profile_kwargs: Optional[Dict] = None,
+ selectors: list,
+ fargate_profile_name: str | None = DEFAULT_FARGATE_PROFILE_NAME,
+ create_fargate_profile_kwargs: dict | None = None,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
@@ -360,7 +360,7 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -408,7 +408,7 @@ def __init__(
cluster_name: str,
force_delete_compute: bool = False,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
@@ -417,7 +417,7 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -528,7 +528,7 @@ def __init__(
cluster_name: str,
nodegroup_name: str,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
) -> None:
self.cluster_name = cluster_name
@@ -537,7 +537,7 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -577,7 +577,7 @@ def __init__(
cluster_name: str,
fargate_profile_name: str,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -586,7 +586,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.region = region
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -646,23 +646,14 @@ def __init__(
# file is stored locally in the worker and not in the cluster.
in_cluster: bool = False,
namespace: str = DEFAULT_NAMESPACE_NAME,
- pod_context: Optional[str] = None,
- pod_name: Optional[str] = None,
- pod_username: Optional[str] = None,
+ pod_context: str | None = None,
+ pod_name: str | None = None,
+ pod_username: str | None = None,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
- is_delete_operator_pod: Optional[bool] = None,
+ region: str | None = None,
+ is_delete_operator_pod: bool | None = None,
**kwargs,
) -> None:
- if pod_name is None:
- warnings.warn(
- "Default value of pod name is deprecated. "
- "We recommend that you pass pod name explicitly. ",
- DeprecationWarning,
- stacklevel=2,
- )
- pod_name = DEFAULT_POD_NAME
-
if is_delete_operator_pod is None:
warnings.warn(
f"You have not set parameter `is_delete_operator_pod` in class {self.__class__.__name__}. "
@@ -687,26 +678,12 @@ def __init__(
is_delete_operator_pod=is_delete_operator_pod,
**kwargs,
)
- if pod_username:
- warnings.warn(
- "This pod_username parameter is deprecated, because changing the value does not make any "
- "visible changes to the user.",
- DeprecationWarning,
- stacklevel=2,
- )
- if pod_context:
- warnings.warn(
- "This pod_context parameter is deprecated, because changing the value does not make any "
- "visible changes to the user.",
- DeprecationWarning,
- stacklevel=2,
- )
# There is no need to manage the kube_config file, as it will be generated automatically.
# All Kubernetes parameters (except config_file) are also valid for the EksPodOperator.
if self.config_file:
raise AirflowException("The config_file is not an allowed parameter for the EksPodOperator.")
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -715,115 +692,3 @@ def execute(self, context: 'Context'):
eks_cluster_name=self.cluster_name, pod_namespace=self.namespace
) as self.config_file:
return super().execute(context)
-
-
-class EKSCreateClusterOperator(EksCreateClusterOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksCreateClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksCreateClusterOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSCreateNodegroupOperator(EksCreateNodegroupOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksCreateNodegroupOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksCreateNodegroupOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSCreateFargateProfileOperator(EksCreateFargateProfileOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksCreateFargateProfileOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksCreateFargateProfileOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSDeleteClusterOperator(EksDeleteClusterOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksDeleteClusterOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksDeleteClusterOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSDeleteNodegroupOperator(EksDeleteNodegroupOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksDeleteNodegroupOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksDeleteNodegroupOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSDeleteFargateProfileOperator(EksDeleteFargateProfileOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksDeleteFargateProfileOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksDeleteFargateProfileOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSPodOperator(EksPodOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.eks.EksPodOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.eks.EksPodOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
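
In the eks.py hunks above, ``EksCreateNodegroupOperator.execute`` accepts ``nodegroup_subnets`` either as a list or as a string (for example, a rendered template) and converts the string form with ``ast.literal_eval``. The sketch below shows that normalization in isolation. ``parse_subnets`` is a hypothetical helper, and its error handling is an assumption, since the ``except`` branch of the real method is not visible in this diff.

```python
from ast import literal_eval


def parse_subnets(value):
    """Return subnets as a list, accepting a list or its string form."""
    if not isinstance(value, str):
        return list(value)
    if value == "":
        # An empty string is treated as "no subnets", mirroring the
        # empty-list default in the operator.
        return []
    try:
        parsed = literal_eval(value)
    except (ValueError, SyntaxError) as err:
        # Assumption: surface a ValueError; the operator may log instead.
        raise ValueError(f"Could not parse subnets from {value!r}") from err
    if not isinstance(parsed, list):
        raise ValueError(f"Expected a list of subnets, got {type(parsed).__name__}")
    return parsed
```

``literal_eval`` only evaluates Python literals, so a templated string such as ``"['subnet-a', 'subnet-b']"`` is converted without executing arbitrary code.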
diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index a1f3fa753d817..63659e8188b9d 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -15,27 +15,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import ast
-import sys
-from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
+import warnings
+from typing import TYPE_CHECKING, Any, Sequence
from uuid import uuid4
from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator, BaseOperatorLink, XCom
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
+from airflow.providers.amazon.aws.links.emr import EmrClusterLink
if TYPE_CHECKING:
- from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.context import Context
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
-from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook
+from airflow.compat.functools import cached_property
class EmrAddStepsOperator(BaseOperator):
@@ -55,28 +50,29 @@ class EmrAddStepsOperator(BaseOperator):
:param aws_conn_id: aws connection to uses
:param steps: boto3 style steps or reference to a steps file (must be '.json') to
be added to the jobflow. (templated)
+ :param wait_for_completion: If True, the operator will wait for all the steps to be completed.
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
"""
- template_fields: Sequence[str] = ('job_flow_id', 'job_flow_name', 'cluster_states', 'steps')
- template_ext: Sequence[str] = ('.json',)
+ template_fields: Sequence[str] = ("job_flow_id", "job_flow_name", "cluster_states", "steps")
+ template_ext: Sequence[str] = (".json",)
template_fields_renderers = {"steps": "json"}
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
+ operator_extra_links = (EmrClusterLink(),)
def __init__(
self,
*,
- job_flow_id: Optional[str] = None,
- job_flow_name: Optional[str] = None,
- cluster_states: Optional[List[str]] = None,
- aws_conn_id: str = 'aws_default',
- steps: Optional[Union[List[dict], str]] = None,
+ job_flow_id: str | None = None,
+ job_flow_name: str | None = None,
+ cluster_states: list[str] | None = None,
+ aws_conn_id: str = "aws_default",
+ steps: list[dict] | str | None = None,
+ wait_for_completion: bool = False,
**kwargs,
):
- if kwargs.get('xcom_push') is not None:
- raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
if not (job_flow_id is None) ^ (job_flow_name is None):
- raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
+ raise AirflowException("Exactly one of job_flow_id or job_flow_name must be specified.")
super().__init__(**kwargs)
cluster_states = cluster_states or []
steps = steps or []
@@ -85,23 +81,30 @@ def __init__(
self.job_flow_name = job_flow_name
self.cluster_states = cluster_states
self.steps = steps
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> List[str]:
+ def execute(self, context: Context) -> list[str]:
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
- emr = emr_hook.get_conn()
-
job_flow_id = self.job_flow_id or emr_hook.get_cluster_id_by_name(
str(self.job_flow_name), self.cluster_states
)
if not job_flow_id:
- raise AirflowException(f'No cluster found for name: {self.job_flow_name}')
+ raise AirflowException(f"No cluster found for name: {self.job_flow_name}")
if self.do_xcom_push:
- context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
+ context["ti"].xcom_push(key="job_flow_id", value=job_flow_id)
+
+ EmrClusterLink.persist(
+ context=context,
+ operator=self,
+ region_name=emr_hook.conn_region_name,
+ aws_partition=emr_hook.conn_partition,
+ job_flow_id=job_flow_id,
+ )
- self.log.info('Adding steps to %s', job_flow_id)
+ self.log.info("Adding steps to %s", job_flow_id)
# steps may arrive as a string representing a list
# e.g. if we used XCom or a file then: steps="[{ step1 }, { step2 }]"
@@ -109,19 +112,73 @@ def execute(self, context: 'Context') -> List[str]:
if isinstance(steps, str):
steps = ast.literal_eval(steps)
- response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=steps)
+ return emr_hook.add_job_flow_steps(job_flow_id=job_flow_id, steps=steps, wait_for_completion=self.wait_for_completion)
- if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
- raise AirflowException(f'Adding steps failed: {response}')
- else:
- self.log.info('Steps %s added to JobFlow', response['StepIds'])
- return response['StepIds']
+
+class EmrEksCreateClusterOperator(BaseOperator):
+ """
+ An operator that creates EMR on EKS virtual clusters.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EmrEksCreateClusterOperator`
+
+ :param virtual_cluster_name: The name of the EMR EKS virtual cluster to create.
+ :param eks_cluster_name: The EKS cluster used by the EMR virtual cluster.
+ :param eks_namespace: namespace used by the EKS cluster.
+ :param virtual_cluster_id: The EMR on EKS virtual cluster id.
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param tags: The tags assigned to created cluster.
+ Defaults to None
+ """
+
+ template_fields: Sequence[str] = (
+ "virtual_cluster_name",
+ "eks_cluster_name",
+ "eks_namespace",
+ )
+ ui_color = "#f9c915"
+
+ def __init__(
+ self,
+ *,
+ virtual_cluster_name: str,
+ eks_cluster_name: str,
+ eks_namespace: str,
+ virtual_cluster_id: str = "",
+ aws_conn_id: str = "aws_default",
+ tags: dict | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.virtual_cluster_name = virtual_cluster_name
+ self.eks_cluster_name = eks_cluster_name
+ self.eks_namespace = eks_namespace
+ self.virtual_cluster_id = virtual_cluster_id
+ self.aws_conn_id = aws_conn_id
+ self.tags = tags
+
+ @cached_property
+ def hook(self) -> EmrContainerHook:
+ """Create and return an EmrContainerHook."""
+ return EmrContainerHook(self.aws_conn_id)
+
+ def execute(self, context: Context) -> str | None:
+ """Create EMR on EKS virtual Cluster"""
+ self.virtual_cluster_id = self.hook.create_emr_on_eks_cluster(
+ self.virtual_cluster_name, self.eks_cluster_name, self.eks_namespace, self.tags
+ )
+ return self.virtual_cluster_id
class EmrContainerOperator(BaseOperator):
"""
An operator that submits jobs to EMR on EKS virtual clusters.
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EmrContainerOperator`
+
:param name: The name of the job run.
:param virtual_cluster_id: The EMR on EKS virtual cluster ID
:param execution_role_arn: The IAM role ARN associated with the job run.
@@ -133,8 +190,10 @@ class EmrContainerOperator(BaseOperator):
Use this if you want to specify a unique ID to prevent two jobs from getting started.
If no token is provided, a UUIDv4 token will be generated for you.
:param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param wait_for_completion: Whether or not to wait in the operator for the job to complete.
:param poll_interval: Time (in seconds) to wait between two consecutive calls to check query status on EMR
- :param max_tries: Maximum number of times to wait for the job run to finish.
+ :param max_tries: Deprecated - use max_polling_attempts instead.
+ :param max_polling_attempts: Maximum number of times to wait for the job run to finish.
Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state.
:param tags: The tags assigned to job runs.
Defaults to None
@@ -157,12 +216,14 @@ def __init__(
execution_role_arn: str,
release_label: str,
job_driver: dict,
- configuration_overrides: Optional[dict] = None,
- client_request_token: Optional[str] = None,
+ configuration_overrides: dict | None = None,
+ client_request_token: str | None = None,
aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
poll_interval: int = 30,
- max_tries: Optional[int] = None,
- tags: Optional[dict] = None,
+ max_tries: int | None = None,
+ tags: dict | None = None,
+ max_polling_attempts: int | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -174,10 +235,23 @@ def __init__(
self.configuration_overrides = configuration_overrides or {}
self.aws_conn_id = aws_conn_id
self.client_request_token = client_request_token or str(uuid4())
+ self.wait_for_completion = wait_for_completion
self.poll_interval = poll_interval
- self.max_tries = max_tries
+ self.max_polling_attempts = max_polling_attempts
self.tags = tags
- self.job_id: Optional[str] = None
+ self.job_id: str | None = None
+
+ if max_tries:
+ warnings.warn(
+ f"Parameter `{self.__class__.__name__}.max_tries` is deprecated and will be removed "
+ "in a future release. Please use parameter `max_polling_attempts` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if max_polling_attempts and max_polling_attempts != max_tries:
+ raise Exception("max_polling_attempts must be the same value as max_tries")
+ else:
+ self.max_polling_attempts = max_tries
@cached_property
def hook(self) -> EmrContainerHook:
@@ -187,7 +261,7 @@ def hook(self) -> EmrContainerHook:
virtual_cluster_id=self.virtual_cluster_id,
)
- def execute(self, context: 'Context') -> Optional[str]:
+ def execute(self, context: Context) -> str | None:
"""Run job on EMR Containers"""
self.job_id = self.hook.submit_job(
self.name,
@@ -198,20 +272,25 @@ def execute(self, context: 'Context') -> Optional[str]:
self.client_request_token,
self.tags,
)
- query_status = self.hook.poll_query_status(self.job_id, self.max_tries, self.poll_interval)
-
- if query_status in EmrContainerHook.FAILURE_STATES:
- error_message = self.hook.get_job_failure_reason(self.job_id)
- raise AirflowException(
- f"EMR Containers job failed. Final state is {query_status}. "
- f"query_execution_id is {self.job_id}. Error: {error_message}"
- )
- elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
- raise AirflowException(
- f"Final state of EMR Containers job is {query_status}. "
- f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
+ if self.wait_for_completion:
+ query_status = self.hook.poll_query_status(
+ self.job_id,
+ max_polling_attempts=self.max_polling_attempts,
+ poll_interval=self.poll_interval,
)
+ if query_status in EmrContainerHook.FAILURE_STATES:
+ error_message = self.hook.get_job_failure_reason(self.job_id)
+ raise AirflowException(
+ f"EMR Containers job failed. Final state is {query_status}. "
+ f"query_execution_id is {self.job_id}. Error: {error_message}"
+ )
+ elif not query_status or query_status in EmrContainerHook.INTERMEDIATE_STATES:
+ raise AirflowException(
+ f"Final state of EMR Containers job is {query_status}. "
+ f"Max tries of poll status exceeded, query_execution_id is {self.job_id}."
+ )
+
return self.job_id
def on_kill(self) -> None:
@@ -235,38 +314,6 @@ def on_kill(self) -> None:
self.hook.poll_query_status(self.job_id)
-class EmrClusterLink(BaseOperatorLink):
- """Operator link for EmrCreateJobFlowOperator. It allows users to access the EMR Cluster"""
-
- name = 'EMR Cluster'
-
- def get_link(
- self,
- operator,
- dttm: Optional[datetime] = None,
- ti_key: Optional["TaskInstanceKey"] = None,
- ) -> str:
- """
- Get link to EMR cluster.
-
- :param operator: operator
- :param dttm: datetime
- :return: url link
- """
- if ti_key is not None:
- flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
- else:
- assert dttm
- flow_id = XCom.get_one(
- key="return_value", dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
- )
- return (
- f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
- if flow_id
- else ''
- )
-
-
class EmrCreateJobFlowOperator(BaseOperator):
"""
Creates an EMR JobFlow, reading the config from the EMR connection.
@@ -277,57 +324,70 @@ class EmrCreateJobFlowOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:EmrCreateJobFlowOperator`
- :param aws_conn_id: aws connection to uses
- :param emr_conn_id: emr connection to use
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ If this is None or empty then the default boto3 behaviour is used. If
+ running Airflow in a distributed manner and aws_conn_id is None or
+ empty, then default boto3 configuration would be used (and must be
+ maintained on each worker node)
+ :param emr_conn_id: :ref:`Amazon Elastic MapReduce Connection `.
+ Use to receive an initial Amazon EMR cluster configuration:
+ ``boto3.client('emr').run_job_flow`` request body.
+ If this is None or empty or the connection does not exist,
+ then an empty initial configuration is used.
:param job_flow_overrides: boto3 style arguments or reference to an arguments file
- (must be '.json') to override emr_connection extra. (templated)
+ (must be '.json') to override specific ``emr_conn_id`` extra parameters. (templated)
:param region_name: Region named passed to EmrHook
"""
- template_fields: Sequence[str] = ('job_flow_overrides',)
- template_ext: Sequence[str] = ('.json',)
+ template_fields: Sequence[str] = ("job_flow_overrides",)
+ template_ext: Sequence[str] = (".json",)
template_fields_renderers = {"job_flow_overrides": "json"}
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
operator_extra_links = (EmrClusterLink(),)
def __init__(
self,
*,
- aws_conn_id: str = 'aws_default',
- emr_conn_id: str = 'emr_default',
- job_flow_overrides: Optional[Union[str, Dict[str, Any]]] = None,
- region_name: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ emr_conn_id: str | None = "emr_default",
+ job_flow_overrides: str | dict[str, Any] | None = None,
+ region_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.emr_conn_id = emr_conn_id
- if job_flow_overrides is None:
- job_flow_overrides = {}
- self.job_flow_overrides = job_flow_overrides
+ self.job_flow_overrides = job_flow_overrides or {}
self.region_name = region_name
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
emr = EmrHook(
aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name
)
self.log.info(
- 'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s', self.aws_conn_id, self.emr_conn_id
+ "Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s", self.aws_conn_id, self.emr_conn_id
)
-
if isinstance(self.job_flow_overrides, str):
- job_flow_overrides: Dict[str, Any] = ast.literal_eval(self.job_flow_overrides)
+ job_flow_overrides: dict[str, Any] = ast.literal_eval(self.job_flow_overrides)
self.job_flow_overrides = job_flow_overrides
else:
job_flow_overrides = self.job_flow_overrides
response = emr.create_job_flow(job_flow_overrides)
- if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
- raise AirflowException(f'JobFlow creation failed: {response}')
+ if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
+ raise AirflowException(f"JobFlow creation failed: {response}")
else:
- self.log.info('JobFlow with id %s created', response['JobFlowId'])
- return response['JobFlowId']
+ job_flow_id = response["JobFlowId"]
+ self.log.info("JobFlow with id %s created", job_flow_id)
+ EmrClusterLink.persist(
+ context=context,
+ operator=self,
+ region_name=emr.conn_region_name,
+ aws_partition=emr.conn_partition,
+ job_flow_id=job_flow_id,
+ )
+ return job_flow_id
class EmrModifyClusterOperator(BaseOperator):
@@ -344,38 +404,44 @@ class EmrModifyClusterOperator(BaseOperator):
:param do_xcom_push: if True, cluster_id is pushed to XCom with key cluster_id.
"""
- template_fields: Sequence[str] = ('cluster_id', 'step_concurrency_level')
+ template_fields: Sequence[str] = ("cluster_id", "step_concurrency_level")
template_ext: Sequence[str] = ()
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
+ operator_extra_links = (EmrClusterLink(),)
def __init__(
- self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = 'aws_default', **kwargs
+ self, *, cluster_id: str, step_concurrency_level: int, aws_conn_id: str = "aws_default", **kwargs
):
- if kwargs.get('xcom_push') is not None:
- raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.cluster_id = cluster_id
self.step_concurrency_level = step_concurrency_level
- def execute(self, context: 'Context') -> int:
+ def execute(self, context: Context) -> int:
emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
-
emr = emr_hook.get_conn()
if self.do_xcom_push:
- context['ti'].xcom_push(key='cluster_id', value=self.cluster_id)
+ context["ti"].xcom_push(key="cluster_id", value=self.cluster_id)
+
+ EmrClusterLink.persist(
+ context=context,
+ operator=self,
+ region_name=emr_hook.conn_region_name,
+ aws_partition=emr_hook.conn_partition,
+ job_flow_id=self.cluster_id,
+ )
- self.log.info('Modifying cluster %s', self.cluster_id)
+ self.log.info("Modifying cluster %s", self.cluster_id)
response = emr.modify_cluster(
ClusterId=self.cluster_id, StepConcurrencyLevel=self.step_concurrency_level
)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Modify cluster failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Modify cluster failed: {response}")
else:
- self.log.info('Steps concurrency level %d', response['StepConcurrencyLevel'])
- return response['StepConcurrencyLevel']
+ self.log.info("Steps concurrency level %d", response["StepConcurrencyLevel"])
+ return response["StepConcurrencyLevel"]
class EmrTerminateJobFlowOperator(BaseOperator):
@@ -390,22 +456,292 @@ class EmrTerminateJobFlowOperator(BaseOperator):
:param aws_conn_id: aws connection to uses
"""
- template_fields: Sequence[str] = ('job_flow_id',)
+ template_fields: Sequence[str] = ("job_flow_id",)
template_ext: Sequence[str] = ()
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
+ operator_extra_links = (EmrClusterLink(),)
- def __init__(self, *, job_flow_id: str, aws_conn_id: str = 'aws_default', **kwargs):
+ def __init__(self, *, job_flow_id: str, aws_conn_id: str = "aws_default", **kwargs):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context') -> None:
- emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
+ def execute(self, context: Context) -> None:
+ emr_hook = EmrHook(aws_conn_id=self.aws_conn_id)
+ emr = emr_hook.get_conn()
+
+ EmrClusterLink.persist(
+ context=context,
+ operator=self,
+ region_name=emr_hook.conn_region_name,
+ aws_partition=emr_hook.conn_partition,
+ job_flow_id=self.job_flow_id,
+ )
- self.log.info('Terminating JobFlow %s', self.job_flow_id)
+ self.log.info("Terminating JobFlow %s", self.job_flow_id)
response = emr.terminate_job_flows(JobFlowIds=[self.job_flow_id])
- if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
- raise AirflowException(f'JobFlow termination failed: {response}')
+ if not response["ResponseMetadata"]["HTTPStatusCode"] == 200:
+ raise AirflowException(f"JobFlow termination failed: {response}")
else:
- self.log.info('JobFlow with id %s terminated', self.job_flow_id)
+ self.log.info("JobFlow with id %s terminated", self.job_flow_id)
+
+
+class EmrServerlessCreateApplicationOperator(BaseOperator):
+ """
+ Operator to create Serverless EMR Application
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EmrServerlessCreateApplicationOperator`
+
+ :param release_label: The EMR release version associated with the application.
+ :param job_type: The type of application you want to start, such as Spark or Hive.
+ :param wait_for_completion: If true, wait for the Application to start before returning. Defaults to True.
+ :param client_request_token: The client idempotency token of the application to create.
+ Its value must be unique for each request.
+ :param config: Optional dictionary for arbitrary parameters to the boto API create_application call.
+ :param aws_conn_id: AWS connection to use
+ """
+
+ def __init__(
+ self,
+ release_label: str,
+ job_type: str,
+ client_request_token: str = "",
+ config: dict | None = None,
+ wait_for_completion: bool = True,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ self.aws_conn_id = aws_conn_id
+ self.release_label = release_label
+ self.job_type = job_type
+ self.wait_for_completion = wait_for_completion
+ self.kwargs = kwargs
+ self.config = config or {}
+ super().__init__(**kwargs)
+
+ self.client_request_token = client_request_token or str(uuid4())
+
+ @cached_property
+ def hook(self) -> EmrServerlessHook:
+ """Create and return an EmrServerlessHook."""
+ return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
+
+ def execute(self, context: Context):
+ response = self.hook.conn.create_application(
+ clientToken=self.client_request_token,
+ releaseLabel=self.release_label,
+ type=self.job_type,
+ **self.config,
+ )
+ application_id = response["applicationId"]
+
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Application Creation failed: {response}")
+
+ self.log.info("EMR serverless application created: %s", application_id)
+
+ # This should be replaced with a boto waiter when available.
+ self.hook.waiter(
+ get_state_callable=self.hook.conn.get_application,
+ get_state_args={"applicationId": application_id},
+ parse_response=["application", "state"],
+ desired_state={"CREATED"},
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
+ object_type="application",
+ action="created",
+ )
+
+ self.log.info("Starting application %s", application_id)
+ self.hook.conn.start_application(applicationId=application_id)
+
+ if self.wait_for_completion:
+ # This should be replaced with a boto waiter when available.
+ self.hook.waiter(
+ get_state_callable=self.hook.conn.get_application,
+ get_state_args={"applicationId": application_id},
+ parse_response=["application", "state"],
+ desired_state={"STARTED"},
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
+ object_type="application",
+ action="started",
+ )
+
+ return application_id
+
+
+class EmrServerlessStartJobOperator(BaseOperator):
+ """
+ Operator to start EMR Serverless job.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EmrServerlessStartJobOperator`
+
+ :param application_id: ID of the EMR Serverless application to start.
+ :param execution_role_arn: ARN of role to perform action.
+ :param job_driver: Driver that the job runs on.
+ :param configuration_overrides: Configuration specifications to override existing configurations.
+ :param client_request_token: The client idempotency token of the job run to start.
+ Its value must be unique for each request.
+ :param config: Optional dictionary for arbitrary parameters to the boto API start_job_run call.
+ :param wait_for_completion: If true, waits for the job to complete before returning. Defaults to True.
+ :param aws_conn_id: AWS connection to use.
+ :param name: Name for the EMR Serverless job. If not provided, a default name will be assigned.
+ """
+
+ template_fields: Sequence[str] = (
+ "application_id",
+ "execution_role_arn",
+ "job_driver",
+ "configuration_overrides",
+ )
+
+ def __init__(
+ self,
+ application_id: str,
+ execution_role_arn: str,
+ job_driver: dict,
+ configuration_overrides: dict | None,
+ client_request_token: str = "",
+ config: dict | None = None,
+ wait_for_completion: bool = True,
+ aws_conn_id: str = "aws_default",
+ name: str | None = None,
+ **kwargs,
+ ):
+ self.aws_conn_id = aws_conn_id
+ self.application_id = application_id
+ self.execution_role_arn = execution_role_arn
+ self.job_driver = job_driver
+ self.configuration_overrides = configuration_overrides
+ self.wait_for_completion = wait_for_completion
+ self.config = config or {}
+ self.name = name or self.config.pop("name", f"emr_serverless_job_airflow_{uuid4()}")
+ super().__init__(**kwargs)
+
+ self.client_request_token = client_request_token or str(uuid4())
+
+ @cached_property
+ def hook(self) -> EmrServerlessHook:
+ """Create and return an EmrServerlessHook."""
+ return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
+
+ def execute(self, context: Context) -> str:
+ self.log.info("Starting job on Application: %s", self.application_id)
+
+ app_state = self.hook.conn.get_application(applicationId=self.application_id)["application"]["state"]
+ if app_state not in EmrServerlessHook.APPLICATION_SUCCESS_STATES:
+ self.hook.conn.start_application(applicationId=self.application_id)
+
+ self.hook.waiter(
+ get_state_callable=self.hook.conn.get_application,
+ get_state_args={"applicationId": self.application_id},
+ parse_response=["application", "state"],
+ desired_state={"STARTED"},
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
+ object_type="application",
+ action="started",
+ )
+
+ response = self.hook.conn.start_job_run(
+ clientToken=self.client_request_token,
+ applicationId=self.application_id,
+ executionRoleArn=self.execution_role_arn,
+ jobDriver=self.job_driver,
+ configurationOverrides=self.configuration_overrides,
+ name=self.name,
+ **self.config,
+ )
+
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"EMR serverless job failed to start: {response}")
+
+ self.log.info("EMR serverless job started: %s", response["jobRunId"])
+ if self.wait_for_completion:
+ # This should be replaced with a boto waiter when available.
+ self.hook.waiter(
+ get_state_callable=self.hook.conn.get_job_run,
+ get_state_args={
+ "applicationId": self.application_id,
+ "jobRunId": response["jobRunId"],
+ },
+ parse_response=["jobRun", "state"],
+ desired_state=EmrServerlessHook.JOB_SUCCESS_STATES,
+ failure_states=EmrServerlessHook.JOB_FAILURE_STATES,
+ object_type="job",
+ action="run",
+ )
+ return response["jobRunId"]
+
+
+class EmrServerlessDeleteApplicationOperator(BaseOperator):
+ """
+ Operator to delete EMR Serverless application
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:EmrServerlessDeleteApplicationOperator`
+
+ :param application_id: ID of the EMR Serverless application to delete.
+ :param wait_for_completion: If true, wait for the Application to be deleted before returning. Defaults to True.
+ :param aws_conn_id: AWS connection to use
+ """
+
+ template_fields: Sequence[str] = ("application_id",)
+
+ def __init__(
+ self,
+ application_id: str,
+ wait_for_completion: bool = True,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ self.aws_conn_id = aws_conn_id
+ self.application_id = application_id
+ self.wait_for_completion = wait_for_completion
+ super().__init__(**kwargs)
+
+ @cached_property
+ def hook(self) -> EmrServerlessHook:
+ """Create and return an EmrServerlessHook."""
+ return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
+
+ def execute(self, context: Context) -> None:
+ self.log.info("Stopping application: %s", self.application_id)
+ self.hook.conn.stop_application(applicationId=self.application_id)
+
+ # This should be replaced with a boto waiter when available.
+ self.hook.waiter(
+ get_state_callable=self.hook.conn.get_application,
+ get_state_args={
+ "applicationId": self.application_id,
+ },
+ parse_response=["application", "state"],
+ desired_state=EmrServerlessHook.APPLICATION_FAILURE_STATES,
+ failure_states=set(),
+ object_type="application",
+ action="stopped",
+ )
+
+ self.log.info("Deleting application: %s", self.application_id)
+ response = self.hook.conn.delete_application(applicationId=self.application_id)
+
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Application deletion failed: {response}")
+
+ if self.wait_for_completion:
+ # This should be replaced with a boto waiter when available.
+ self.hook.waiter(
+ get_state_callable=self.hook.conn.get_application,
+ get_state_args={"applicationId": self.application_id},
+ parse_response=["application", "state"],
+ desired_state={"TERMINATED"},
+ failure_states=EmrServerlessHook.APPLICATION_FAILURE_STATES,
+ object_type="application",
+ action="deleted",
+ )
+
+ self.log.info("EMR serverless application deleted")
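Note on the `max_tries` handling in `EmrContainerOperator.__init__` in the emr.py hunks above: the deprecated parameter and its replacement are reconciled into a single value. The following is a minimal standalone sketch of that logic, not the provider's code. The helper name `resolve_polling_attempts` is hypothetical, and it raises `ValueError` where the diff raises a bare `Exception`.

```python
from __future__ import annotations

import warnings


def resolve_polling_attempts(max_tries: int | None, max_polling_attempts: int | None) -> int | None:
    """Hypothetical helper: pick the effective polling limit from the old and new parameters."""
    if max_tries:
        # The old parameter still works, but the caller is told to migrate.
        warnings.warn(
            "Parameter `max_tries` is deprecated; use `max_polling_attempts` instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # Both parameters may be given only if they agree.
        if max_polling_attempts and max_polling_attempts != max_tries:
            raise ValueError("max_polling_attempts must be the same value as max_tries")
        return max_tries
    # Only the new parameter (or neither) was given; None means "poll until a final state".
    return max_polling_attempts
```

Passing only `max_tries=5` yields 5 plus a `DeprecationWarning`, passing only `max_polling_attempts=7` yields 7 silently, and passing conflicting values raises.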
diff --git a/airflow/providers/amazon/aws/operators/emr_add_steps.py b/airflow/providers/amazon/aws/operators/emr_add_steps.py
deleted file mode 100644
index c99be43e8cdf5..0000000000000
--- a/airflow/providers/amazon/aws/operators/emr_add_steps.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr import EmrAddStepsOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/emr_containers.py b/airflow/providers/amazon/aws/operators/emr_containers.py
deleted file mode 100644
index 6e81b0bddea0c..0000000000000
--- a/airflow/providers/amazon/aws/operators/emr_containers.py
+++ /dev/null
@@ -1,44 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class EMRContainerOperator(EmrContainerOperator):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.amazon.aws.operators.emr.EmrContainerOperator`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
diff --git a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py b/airflow/providers/amazon/aws/operators/emr_create_job_flow.py
deleted file mode 100644
index 18c052fa2c835..0000000000000
--- a/airflow/providers/amazon/aws/operators/emr_create_job_flow.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr import EmrClusterLink, EmrCreateJobFlowOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-__all__ = ["EmrClusterLink", "EmrCreateJobFlowOperator"]
diff --git a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py b/airflow/providers/amazon/aws/operators/emr_modify_cluster.py
deleted file mode 100644
index 71b44d5364e42..0000000000000
--- a/airflow/providers/amazon/aws/operators/emr_modify_cluster.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr import EmrClusterLink, EmrModifyClusterOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py b/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
deleted file mode 100644
index f924393a443ad..0000000000000
--- a/airflow/providers/amazon/aws/operators/emr_terminate_job_flow.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.emr import EmrTerminateJobFlowOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/glacier.py b/airflow/providers/amazon/aws/operators/glacier.py
index 337492a4523e1..4e7c8b5e17421 100644
--- a/airflow/providers/amazon/aws/operators/glacier.py
+++ b/airflow/providers/amazon/aws/operators/glacier.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
@@ -49,6 +51,56 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.vault_name = vault_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hook = GlacierHook(aws_conn_id=self.aws_conn_id)
return hook.retrieve_inventory(vault_name=self.vault_name)
+
+
+class GlacierUploadArchiveOperator(BaseOperator):
+ """
+ This operator adds an archive to an Amazon S3 Glacier vault
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:GlacierUploadArchiveOperator`
+
+ :param vault_name: The name of the vault
+ :param body: A bytes or seekable file-like object. The data to upload.
+ :param checksum: The SHA256 tree hash of the data being uploaded.
+ This parameter is automatically populated if it is not provided
+ :param archive_description: The description of the archive you are uploading
+ :param account_id: (Optional) AWS account ID of the account that owns the vault.
+ Defaults to the credentials used to sign the request
+ :param aws_conn_id: The reference to the AWS connection details
+ """
+
+ template_fields: Sequence[str] = ("vault_name",)
+
+ def __init__(
+ self,
+ *,
+ vault_name: str,
+ body: object,
+ checksum: str | None = None,
+ archive_description: str | None = None,
+ account_id: str | None = None,
+ aws_conn_id="aws_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.aws_conn_id = aws_conn_id
+ self.account_id = account_id
+ self.vault_name = vault_name
+ self.body = body
+ self.checksum = checksum
+ self.archive_description = archive_description
+
+ def execute(self, context: Context):
+ hook = GlacierHook(aws_conn_id=self.aws_conn_id)
+ return hook.get_conn().upload_archive(
+ accountId=self.account_id,
+ vaultName=self.vault_name,
+ archiveDescription=self.archive_description,
+ body=self.body,
+ checksum=self.checksum,
+ )
diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py
index cb76e784300e0..37a08a216a0b4 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import os.path
-import warnings
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook
@@ -51,33 +51,41 @@ class GlueJobOperator(BaseOperator):
:param create_job_kwargs: Extra arguments for Glue Job Creation
:param run_job_kwargs: Extra arguments for Glue Job Run
:param wait_for_completion: Whether or not to wait for job run completion. (default: True)
+ :param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False)
"""
- template_fields: Sequence[str] = ('script_args',)
+ template_fields: Sequence[str] = (
+ "job_name",
+ "script_location",
+ "script_args",
+ "s3_bucket",
+ "iam_role_name",
+ )
template_ext: Sequence[str] = ()
template_fields_renderers = {
"script_args": "json",
"create_job_kwargs": "json",
}
- ui_color = '#ededed'
+ ui_color = "#ededed"
def __init__(
self,
*,
- job_name: str = 'aws_glue_default_job',
- job_desc: str = 'AWS Glue Job with Airflow',
- script_location: str,
- concurrent_run_limit: Optional[int] = None,
- script_args: Optional[dict] = None,
+ job_name: str = "aws_glue_default_job",
+ job_desc: str = "AWS Glue Job with Airflow",
+ script_location: str | None = None,
+ concurrent_run_limit: int | None = None,
+ script_args: dict | None = None,
retry_limit: int = 0,
- num_of_dpus: Optional[int] = None,
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
- s3_bucket: Optional[str] = None,
- iam_role_name: Optional[str] = None,
- create_job_kwargs: Optional[dict] = None,
- run_job_kwargs: Optional[dict] = None,
+ num_of_dpus: int | None = None,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ s3_bucket: str | None = None,
+ iam_role_name: str | None = None,
+ create_job_kwargs: dict | None = None,
+ run_job_kwargs: dict | None = None,
wait_for_completion: bool = True,
+ verbose: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -93,12 +101,13 @@ def __init__(
self.s3_bucket = s3_bucket
self.iam_role_name = iam_role_name
self.s3_protocol = "s3://"
- self.s3_artifacts_prefix = 'artifacts/glue-scripts/'
+ self.s3_artifacts_prefix = "artifacts/glue-scripts/"
self.create_job_kwargs = create_job_kwargs
self.run_job_kwargs = run_job_kwargs or {}
self.wait_for_completion = wait_for_completion
+ self.verbose = verbose
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Executes AWS Glue Job from Airflow
@@ -135,29 +144,13 @@ def execute(self, context: 'Context'):
)
glue_job_run = glue_job.initialize_job(self.script_args, self.run_job_kwargs)
if self.wait_for_completion:
- glue_job_run = glue_job.job_completion(self.job_name, glue_job_run['JobRunId'])
+ glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose)
self.log.info(
"AWS Glue Job: %s status: %s. Run Id: %s",
self.job_name,
- glue_job_run['JobRunState'],
- glue_job_run['JobRunId'],
+ glue_job_run["JobRunState"],
+ glue_job_run["JobRunId"],
)
else:
- self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run['JobRunId'])
- return glue_job_run['JobRunId']
-
-
-class AwsGlueJobOperator(GlueJobOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.glue.GlueJobOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.operators.glue.GlueJobOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
+ self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"])
+ return glue_job_run["JobRunId"]
diff --git a/airflow/providers/amazon/aws/operators/glue_crawler.py b/airflow/providers/amazon/aws/operators/glue_crawler.py
index a584301f926d1..1e30be1b262bd 100644
--- a/airflow/providers/amazon/aws/operators/glue_crawler.py
+++ b/airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -15,19 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
-import warnings
-from typing import TYPE_CHECKING
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
if TYPE_CHECKING:
from airflow.utils.context import Context
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
+from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
@@ -48,12 +43,13 @@ class GlueCrawlerOperator(BaseOperator):
:param wait_for_completion: Whether or not to wait for crawl execution completion. (default: True)
"""
- ui_color = '#ededed'
+ template_fields: Sequence[str] = ("config",)
+ ui_color = "#ededed"
def __init__(
self,
config,
- aws_conn_id='aws_default',
+ aws_conn_id="aws_default",
poll_interval: int = 5,
wait_for_completion: bool = True,
**kwargs,
@@ -69,13 +65,13 @@ def hook(self) -> GlueCrawlerHook:
"""Create and return a GlueCrawlerHook."""
return GlueCrawlerHook(self.aws_conn_id)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""
Executes AWS Glue Crawler from Airflow
:return: the name of the current glue crawler.
"""
- crawler_name = self.config['Name']
+ crawler_name = self.config["Name"]
if self.hook.has_crawler(crawler_name):
self.hook.update_crawler(**self.config)
else:
@@ -88,19 +84,3 @@ def execute(self, context: 'Context'):
self.hook.wait_for_crawler_completion(crawler_name=crawler_name, poll_interval=self.poll_interval)
return crawler_name
-
-
-class AwsGlueCrawlerOperator(GlueCrawlerOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.operators.glue_crawler.GlueCrawlerOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/lambda_function.py b/airflow/providers/amazon/aws/operators/lambda_function.py
new file mode 100644
index 0000000000000..59e0012b0530f
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/lambda_function.py
@@ -0,0 +1,103 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Sequence
+
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.lambda_function import LambdaHook
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class AwsLambdaInvokeFunctionOperator(BaseOperator):
+ """
+ Invokes an AWS Lambda function.
+ You can invoke a function synchronously (and wait for the response),
+ or asynchronously.
+ To invoke a function asynchronously,
+ set `invocation_type` to `Event`. For more details,
+ review the boto3 Lambda invoke docs.
+
+ :param function_name: The name of the AWS Lambda function, version, or alias.
+ :param payload: The JSON string that you want to provide to your Lambda function as input.
+ :param log_type: Set to Tail to include the execution log in the response. Otherwise, set to "None".
+ :param qualifier: Specify a version or alias to invoke a published version of the function.
+ :param aws_conn_id: The AWS connection ID to use
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:AwsLambdaInvokeFunctionOperator`
+
+ """
+
+ template_fields: Sequence[str] = ("function_name", "payload", "qualifier", "invocation_type")
+ ui_color = "#ff7300"
+
+ def __init__(
+ self,
+ *,
+ function_name: str,
+ log_type: str | None = None,
+ qualifier: str | None = None,
+ invocation_type: str | None = None,
+ client_context: str | None = None,
+ payload: str | None = None,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.function_name = function_name
+ self.payload = payload
+ self.log_type = log_type
+ self.qualifier = qualifier
+ self.invocation_type = invocation_type
+ self.client_context = client_context
+ self.aws_conn_id = aws_conn_id
+
+ def execute(self, context: Context):
+ """
+ Invokes the target AWS Lambda function from Airflow.
+
+ :return: The response payload from the function, or an error object.
+ """
+ hook = LambdaHook(aws_conn_id=self.aws_conn_id)
+ success_status_codes = [200, 202, 204]
+ self.log.info("Invoking AWS Lambda function: %s with payload: %s", self.function_name, self.payload)
+ response = hook.invoke_lambda(
+ function_name=self.function_name,
+ invocation_type=self.invocation_type,
+ log_type=self.log_type,
+ client_context=self.client_context,
+ payload=self.payload,
+ qualifier=self.qualifier,
+ )
+ self.log.info("Lambda response metadata: %r", response.get("ResponseMetadata"))
+ if response.get("StatusCode") not in success_status_codes:
+ raise ValueError("Lambda function did not execute", json.dumps(response.get("ResponseMetadata")))
+ payload_stream = response.get("Payload")
+ payload = payload_stream.read().decode()
+ if "FunctionError" in response:
+ raise ValueError(
+ "Lambda function execution resulted in error",
+ {"ResponseMetadata": response.get("ResponseMetadata"), "Payload": payload},
+ )
+ self.log.info("Lambda function invocation succeeded: %r", response.get("ResponseMetadata"))
+ return payload
diff --git a/airflow/providers/amazon/aws/operators/quicksight.py b/airflow/providers/amazon/aws/operators/quicksight.py
index a9da61bdfc3bf..85514af806c2f 100644
--- a/airflow/providers/amazon/aws/operators/quicksight.py
+++ b/airflow/providers/amazon/aws/operators/quicksight.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
@@ -71,7 +72,7 @@ def __init__(
wait_for_completion: bool = True,
check_interval: int = 30,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
):
self.data_set_id = data_set_id
@@ -83,12 +84,12 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def execute(self, context: "Context"):
+ def execute(self, context: Context):
hook = QuickSightHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
)
- self.log.info("Running the Amazon QuickSight SPICE Ingestion on Dataset ID: %s)", self.data_set_id)
+ self.log.info("Running the Amazon QuickSight SPICE Ingestion on Dataset ID: %s", self.data_set_id)
return hook.create_ingestion(
data_set_id=self.data_set_id,
ingestion_id=self.ingestion_id,
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index a527107e80a46..c10e969c8bc2d 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -15,14 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
-import time
-from typing import TYPE_CHECKING, List, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from mypy_boto3_rds.type_defs import TagTypeDef
-from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.rds import RdsDbType
@@ -37,64 +36,14 @@ class RdsBaseOperator(BaseOperator):
ui_color = "#eeaa88"
ui_fgcolor = "#ffffff"
- def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: Optional[dict] = None, **kwargs):
+ def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs):
hook_params = hook_params or {}
self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params)
super().__init__(*args, **kwargs)
self._await_interval = 60 # seconds
- def _describe_item(self, item_type: str, item_name: str) -> list:
-
- if item_type == 'instance_snapshot':
- db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
- return db_snaps['DBSnapshots']
- elif item_type == 'cluster_snapshot':
- cl_snaps = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
- return cl_snaps['DBClusterSnapshots']
- elif item_type == 'export_task':
- exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
- return exports['ExportTasks']
- elif item_type == 'event_subscription':
- subscriptions = self.hook.conn.describe_event_subscriptions(SubscriptionName=item_name)
- return subscriptions['EventSubscriptionsList']
- else:
- raise AirflowException(f"Method for {item_type} is not implemented")
-
- def _await_status(
- self,
- item_type: str,
- item_name: str,
- wait_statuses: Optional[List[str]] = None,
- ok_statuses: Optional[List[str]] = None,
- error_statuses: Optional[List[str]] = None,
- ) -> None:
- """
- Continuously gets item description from `_describe_item()` and waits until:
- - status is in `wait_statuses`
- - status not in `ok_statuses` and `error_statuses`
- """
- while True:
- items = self._describe_item(item_type, item_name)
-
- if len(items) == 0:
- raise AirflowException(f"There is no {item_type} with identifier {item_name}")
- if len(items) > 1:
- raise AirflowException(f"There are {len(items)} {item_type} with identifier {item_name}")
-
- if wait_statuses and items[0]['Status'].lower() in wait_statuses:
- time.sleep(self._await_interval)
- continue
- elif ok_statuses and items[0]['Status'].lower() in ok_statuses:
- break
- elif error_statuses and items[0]['Status'].lower() in error_statuses:
- raise AirflowException(f"Item has error status ({error_statuses}): {items[0]}")
- else:
- raise AirflowException(f"Item has uncertain status: {items[0]}")
-
- return None
-
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
"""Different implementations for snapshots, tasks and events"""
raise NotImplementedError
@@ -117,6 +66,7 @@ class RdsCreateDbSnapshotOperator(RdsBaseOperator):
:param db_snapshot_identifier: The identifier for the DB snapshot
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
`USER Tagging `__
+ :param wait_for_completion: If True, waits for creation of the DB snapshot to complete. (default: True)
"""
template_fields = ("db_snapshot_identifier", "db_identifier", "tags")
@@ -127,7 +77,8 @@ def __init__(
db_type: str,
db_identifier: str,
db_snapshot_identifier: str,
- tags: Optional[Sequence[TagTypeDef]] = None,
+ tags: Sequence[TagTypeDef] | None = None,
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_conn_id",
**kwargs,
):
@@ -136,8 +87,9 @@ def __init__(
self.db_identifier = db_identifier
self.db_snapshot_identifier = db_snapshot_identifier
self.tags = tags or []
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info(
"Starting to create snapshot of RDS %s '%s': %s",
self.db_type,
@@ -152,12 +104,8 @@ def execute(self, context: 'Context') -> str:
Tags=self.tags,
)
create_response = json.dumps(create_instance_snap, default=str)
- self._await_status(
- 'instance_snapshot',
- self.db_snapshot_identifier,
- wait_statuses=['creating'],
- ok_statuses=['available'],
- )
+ if self.wait_for_completion:
+ self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier, target_state="available")
else:
create_cluster_snap = self.hook.conn.create_db_cluster_snapshot(
DBClusterIdentifier=self.db_identifier,
@@ -165,13 +113,10 @@ def execute(self, context: 'Context') -> str:
Tags=self.tags,
)
create_response = json.dumps(create_cluster_snap, default=str)
- self._await_status(
- 'cluster_snapshot',
- self.db_snapshot_identifier,
- wait_statuses=['creating'],
- ok_statuses=['available'],
- )
-
+ if self.wait_for_completion:
+ self.hook.wait_for_db_cluster_snapshot_state(
+ self.db_snapshot_identifier, target_state="available"
+ )
return create_response
@@ -196,6 +141,7 @@ class RdsCopyDbSnapshotOperator(RdsBaseOperator):
:param target_custom_availability_zone: The external custom Availability Zone identifier for the target
Only when db_type='instance'
:param source_region: The ID of the region that contains the snapshot to be copied
+ :param wait_for_completion: If True, waits for snapshot copy to complete. (default: True)
"""
template_fields = (
@@ -213,12 +159,13 @@ def __init__(
source_db_snapshot_identifier: str,
target_db_snapshot_identifier: str,
kms_key_id: str = "",
- tags: Optional[Sequence[TagTypeDef]] = None,
+ tags: Sequence[TagTypeDef] | None = None,
copy_tags: bool = False,
pre_signed_url: str = "",
option_group_name: str = "",
target_custom_availability_zone: str = "",
source_region: str = "",
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -234,8 +181,9 @@ def __init__(
self.option_group_name = option_group_name
self.target_custom_availability_zone = target_custom_availability_zone
self.source_region = source_region
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info(
"Starting to copy snapshot '%s' as '%s'",
self.source_db_snapshot_identifier,
@@ -255,12 +203,10 @@ def execute(self, context: 'Context') -> str:
SourceRegion=self.source_region,
)
copy_response = json.dumps(copy_instance_snap, default=str)
- self._await_status(
- 'instance_snapshot',
- self.target_db_snapshot_identifier,
- wait_statuses=['creating'],
- ok_statuses=['available'],
- )
+ if self.wait_for_completion:
+ self.hook.wait_for_db_snapshot_state(
+ self.target_db_snapshot_identifier, target_state="available"
+ )
else:
copy_cluster_snap = self.hook.conn.copy_db_cluster_snapshot(
SourceDBClusterSnapshotIdentifier=self.source_db_snapshot_identifier,
@@ -272,13 +218,10 @@ def execute(self, context: 'Context') -> str:
SourceRegion=self.source_region,
)
copy_response = json.dumps(copy_cluster_snap, default=str)
- self._await_status(
- 'cluster_snapshot',
- self.target_db_snapshot_identifier,
- wait_statuses=['copying'],
- ok_statuses=['available'],
- )
-
+ if self.wait_for_completion:
+ self.hook.wait_for_db_cluster_snapshot_state(
+ self.target_db_snapshot_identifier, target_state="available"
+ )
return copy_response
@@ -301,6 +244,7 @@ def __init__(
*,
db_type: str,
db_snapshot_identifier: str,
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -308,8 +252,9 @@ def __init__(
self.db_type = RdsDbType(db_type)
self.db_snapshot_identifier = db_snapshot_identifier
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info("Starting to delete snapshot '%s'", self.db_snapshot_identifier)
if self.db_type.value == "instance":
@@ -317,11 +262,17 @@ def execute(self, context: 'Context') -> str:
DBSnapshotIdentifier=self.db_snapshot_identifier,
)
delete_response = json.dumps(delete_instance_snap, default=str)
+ if self.wait_for_completion:
+ self.hook.wait_for_db_snapshot_state(self.db_snapshot_identifier, target_state="deleted")
else:
delete_cluster_snap = self.hook.conn.delete_db_cluster_snapshot(
DBClusterSnapshotIdentifier=self.db_snapshot_identifier,
)
delete_response = json.dumps(delete_cluster_snap, default=str)
+ if self.wait_for_completion:
+ self.hook.wait_for_db_cluster_snapshot_state(
+ self.db_snapshot_identifier, target_state="deleted"
+ )
return delete_response
@@ -341,6 +292,7 @@ class RdsStartExportTaskOperator(RdsBaseOperator):
:param kms_key_id: The ID of the Amazon Web Services KMS key to use to encrypt the snapshot.
:param s3_prefix: The Amazon S3 bucket prefix to use as the file name and path of the exported snapshot.
:param export_only: The data to be exported from the snapshot.
+ :param wait_for_completion: If True, waits for the DB snapshot export to complete. (default: True)
"""
template_fields = (
@@ -361,8 +313,9 @@ def __init__(
s3_bucket_name: str,
iam_role_arn: str,
kms_key_id: str,
- s3_prefix: str = '',
- export_only: Optional[List[str]] = None,
+ s3_prefix: str = "",
+ export_only: list[str] | None = None,
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -375,8 +328,9 @@ def __init__(
self.kms_key_id = kms_key_id
self.s3_prefix = s3_prefix
self.export_only = export_only or []
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info("Starting export task %s for snapshot %s", self.export_task_identifier, self.source_arn)
start_export = self.hook.conn.start_export_task(
@@ -389,14 +343,8 @@ def execute(self, context: 'Context') -> str:
ExportOnly=self.export_only,
)
- self._await_status(
- 'export_task',
- self.export_task_identifier,
- wait_statuses=['starting', 'in_progress'],
- ok_statuses=['complete'],
- error_statuses=['canceling', 'canceled'],
- )
-
+ if self.wait_for_completion:
+ self.hook.wait_for_export_task_state(self.export_task_identifier, target_state="complete")
return json.dumps(start_export, default=str)
@@ -409,6 +357,7 @@ class RdsCancelExportTaskOperator(RdsBaseOperator):
:ref:`howto/operator:RdsCancelExportTaskOperator`
:param export_task_identifier: The identifier of the snapshot export task to cancel
+ :param wait_for_completion: If True, waits for DB snapshot export to cancel. (default: True)
"""
template_fields = ("export_task_identifier",)
@@ -417,26 +366,24 @@ def __init__(
self,
*,
export_task_identifier: str,
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
self.export_task_identifier = export_task_identifier
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info("Canceling export task %s", self.export_task_identifier)
cancel_export = self.hook.conn.cancel_export_task(
ExportTaskIdentifier=self.export_task_identifier,
)
- self._await_status(
- 'export_task',
- self.export_task_identifier,
- wait_statuses=['canceling'],
- ok_statuses=['canceled'],
- )
+ if self.wait_for_completion:
+ self.hook.wait_for_export_task_state(self.export_task_identifier, target_state="canceled")
return json.dumps(cancel_export, default=str)
@@ -458,6 +405,7 @@ class RdsCreateEventSubscriptionOperator(RdsBaseOperator):
:param enabled: A value that indicates whether to activate the subscription (default: True)
:param tags: A list of tags in format `[{"Key": "something", "Value": "something"},]
`USER Tagging `__
+ :param wait_for_completion: If True, waits for creation of the subscription to complete. (default: True)
"""
template_fields = (
@@ -475,10 +423,11 @@ def __init__(
subscription_name: str,
sns_topic_arn: str,
source_type: str = "",
- event_categories: Optional[Sequence[str]] = None,
- source_ids: Optional[Sequence[str]] = None,
+ event_categories: Sequence[str] | None = None,
+ source_ids: Sequence[str] | None = None,
enabled: bool = True,
- tags: Optional[Sequence[TagTypeDef]] = None,
+ tags: Sequence[TagTypeDef] | None = None,
+ wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -491,8 +440,9 @@ def __init__(
self.source_ids = source_ids or []
self.enabled = enabled
self.tags = tags or []
+ self.wait_for_completion = wait_for_completion
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info("Creating event subscription '%s' to '%s'", self.subscription_name, self.sns_topic_arn)
create_subscription = self.hook.conn.create_event_subscription(
@@ -504,13 +454,9 @@ def execute(self, context: 'Context') -> str:
Enabled=self.enabled,
Tags=self.tags,
)
- self._await_status(
- 'event_subscription',
- self.subscription_name,
- wait_statuses=['creating'],
- ok_statuses=['active'],
- )
+ if self.wait_for_completion:
+ self.hook.wait_for_event_subscription_state(self.subscription_name, target_state="active")
return json.dumps(create_subscription, default=str)
@@ -538,7 +484,7 @@ def __init__(
self.subscription_name = subscription_name
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
self.log.info(
"Deleting event subscription %s",
self.subscription_name,
@@ -551,6 +497,225 @@ def execute(self, context: 'Context') -> str:
return json.dumps(delete_subscription, default=str)
+class RdsCreateDbInstanceOperator(RdsBaseOperator):
+ """
+ Creates an RDS DB instance
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsCreateDbInstanceOperator`
+
+ :param db_instance_identifier: The DB instance identifier, must start with a letter and
+ contain from 1 to 63 letters, numbers, or hyphens
+ :param db_instance_class: The compute and memory capacity of the DB instance, for example db.m5.large
+ :param engine: The name of the database engine to be used for this instance
+ :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``create_db_instance``
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.create_db_instance
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param wait_for_completion: If True, waits for creation of the DB instance to complete. (default: True)
+ """
+
+ template_fields = ("db_instance_identifier", "db_instance_class", "engine", "rds_kwargs")
+
+ def __init__(
+ self,
+ *,
+ db_instance_identifier: str,
+ db_instance_class: str,
+ engine: str,
+ rds_kwargs: dict | None = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+
+ self.db_instance_identifier = db_instance_identifier
+ self.db_instance_class = db_instance_class
+ self.engine = engine
+ self.rds_kwargs = rds_kwargs or {}
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.log.info("Creating new DB instance %s", self.db_instance_identifier)
+
+ create_db_instance = self.hook.conn.create_db_instance(
+ DBInstanceIdentifier=self.db_instance_identifier,
+ DBInstanceClass=self.db_instance_class,
+ Engine=self.engine,
+ **self.rds_kwargs,
+ )
+
+ if self.wait_for_completion:
+ self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="available")
+ return json.dumps(create_db_instance, default=str)
+
+
+class RdsDeleteDbInstanceOperator(RdsBaseOperator):
+ """
+ Deletes an RDS DB Instance
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsDeleteDbInstanceOperator`
+
+ :param db_instance_identifier: The DB instance identifier for the DB instance to be deleted
+ :param rds_kwargs: Named arguments to pass to boto3 RDS client function ``delete_db_instance``
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/rds.html#RDS.Client.delete_db_instance
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ :param wait_for_completion: If True, waits for deletion of the DB instance to complete. (default: True)
+ """
+
+ template_fields = ("db_instance_identifier", "rds_kwargs")
+
+ def __init__(
+ self,
+ *,
+ db_instance_identifier: str,
+ rds_kwargs: dict | None = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_instance_identifier = db_instance_identifier
+ self.rds_kwargs = rds_kwargs or {}
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.log.info("Deleting DB instance %s", self.db_instance_identifier)
+
+ delete_db_instance = self.hook.conn.delete_db_instance(
+ DBInstanceIdentifier=self.db_instance_identifier,
+ **self.rds_kwargs,
+ )
+
+ if self.wait_for_completion:
+ self.hook.wait_for_db_instance_state(self.db_instance_identifier, target_state="deleted")
+ return json.dumps(delete_db_instance, default=str)
+
+
+class RdsStartDbOperator(RdsBaseOperator):
+ """
+ Starts an RDS DB instance / cluster
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsStartDbOperator`
+
+ :param db_identifier: The AWS identifier of the DB to start
+ :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
+ :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
+ :param wait_for_completion: If True, waits for DB to start. (default: True)
+ """
+
+ template_fields = ("db_identifier", "db_type")
+
+ def __init__(
+ self,
+ *,
+ db_identifier: str,
+ db_type: RdsDbType | str = RdsDbType.INSTANCE,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_identifier = db_identifier
+ self.db_type = db_type
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.db_type = RdsDbType(self.db_type)
+ start_db_response = self._start_db()
+ if self.wait_for_completion:
+ self._wait_until_db_available()
+ return json.dumps(start_db_response, default=str)
+
+ def _start_db(self):
+ self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
+ if self.db_type == RdsDbType.INSTANCE:
+ response = self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
+ else:
+ response = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)
+ return response
+
+ def _wait_until_db_available(self):
+ self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value)
+ if self.db_type == RdsDbType.INSTANCE:
+ self.hook.wait_for_db_instance_state(self.db_identifier, target_state="available")
+ else:
+ self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="available")
+
+
+class RdsStopDbOperator(RdsBaseOperator):
+ """
+ Stops an RDS DB instance / cluster
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RdsStopDbOperator`
+
+ :param db_identifier: The AWS identifier of the DB to stop
+ :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
+ :param db_snapshot_identifier: The instance identifier of the DB Snapshot to create before
+ stopping the DB instance. The default value (None) skips snapshot creation. This
+ parameter is ignored when ``db_type`` is "cluster"
+ :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
+ :param wait_for_completion: If True, waits for DB to stop. (default: True)
+ """
+
+ template_fields = ("db_identifier", "db_snapshot_identifier", "db_type")
+
+ def __init__(
+ self,
+ *,
+ db_identifier: str,
+ db_type: RdsDbType | str = RdsDbType.INSTANCE,
+ db_snapshot_identifier: str | None = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_identifier = db_identifier
+ self.db_type = db_type
+ self.db_snapshot_identifier = db_snapshot_identifier
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.db_type = RdsDbType(self.db_type)
+ stop_db_response = self._stop_db()
+ if self.wait_for_completion:
+ self._wait_until_db_stopped()
+ return json.dumps(stop_db_response, default=str)
+
+ def _stop_db(self):
+ self.log.info("Stopping DB %s '%s'", self.db_type.value, self.db_identifier)
+ if self.db_type == RdsDbType.INSTANCE:
+ conn_params = {"DBInstanceIdentifier": self.db_identifier}
+ # The db snapshot parameter is optional, but the AWS SDK raises an exception
+ # if passed a null value. Only set snapshot id if value is present.
+ if self.db_snapshot_identifier:
+ conn_params["DBSnapshotIdentifier"] = self.db_snapshot_identifier
+ response = self.hook.conn.stop_db_instance(**conn_params)
+ else:
+ if self.db_snapshot_identifier:
+ self.log.warning(
+ "'db_snapshot_identifier' does not apply to db clusters. "
+ "Remove it to silence this warning."
+ )
+ response = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier)
+ return response
+
+ def _wait_until_db_stopped(self):
+ self.log.info("Waiting for DB %s to reach 'stopped' state", self.db_type.value)
+ if self.db_type == RdsDbType.INSTANCE:
+ self.hook.wait_for_db_instance_state(self.db_identifier, target_state="stopped")
+ else:
+ self.hook.wait_for_db_cluster_state(self.db_identifier, target_state="stopped")
+
+
__all__ = [
"RdsCreateDbSnapshotOperator",
"RdsCopyDbSnapshotOperator",
@@ -559,4 +724,8 @@ def execute(self, context: 'Context') -> str:
"RdsDeleteEventSubscriptionOperator",
"RdsStartExportTaskOperator",
"RdsCancelExportTaskOperator",
+ "RdsCreateDbInstanceOperator",
+ "RdsDeleteDbInstanceOperator",
+ "RdsStartDbOperator",
+ "RdsStopDbOperator",
]
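Reviewer note on `RdsStopDbOperator._stop_db` above: the snapshot identifier is only forwarded for instances, and only when it is set. The following is a minimal, standalone sketch of that argument-building step. It does not import Airflow or boto3; `DbType` and `build_stop_call` are hypothetical names used for illustration, with `DbType` standing in for the `RdsDbType` enum referenced in the diff.

```python
from __future__ import annotations

from enum import Enum


class DbType(Enum):
    """Illustrative stand-in for the RdsDbType enum used by the operator."""

    INSTANCE = "instance"
    CLUSTER = "cluster"


def build_stop_call(
    db_identifier: str,
    db_type: DbType | str = DbType.INSTANCE,
    db_snapshot_identifier: str | None = None,
) -> tuple[str, dict[str, str]]:
    """Return the client method name and keyword arguments for a stop request."""
    # Accept either the enum member or its string value, as execute() does.
    db_type = DbType(db_type)
    if db_type == DbType.INSTANCE:
        params = {"DBInstanceIdentifier": db_identifier}
        # Only include the snapshot id when a value is present, so that
        # a None value is never sent to the client.
        if db_snapshot_identifier:
            params["DBSnapshotIdentifier"] = db_snapshot_identifier
        return "stop_db_instance", params
    # Clusters take no snapshot id; it is dropped here (the operator also logs a warning).
    return "stop_db_cluster", {"DBClusterIdentifier": db_identifier}
```

Separating the argument building from the client call, as sketched here, makes the instance/cluster branching testable without AWS access; whether the operator should be structured this way is a design choice and not something the diff itself does.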
diff --git a/airflow/providers/amazon/aws/operators/redshift.py b/airflow/providers/amazon/aws/operators/redshift.py
deleted file mode 100644
index efcf44a797dea..0000000000000
--- a/airflow/providers/amazon/aws/operators/redshift.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import warnings
-
-from airflow.providers.amazon.aws.operators.redshift_cluster import (
- RedshiftPauseClusterOperator,
- RedshiftResumeClusterOperator,
-)
-from airflow.providers.amazon.aws.operators.redshift_sql import RedshiftSQLOperator
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.redshift_sql` "
- "or `airflow.providers.amazon.aws.operators.redshift_cluster` as appropriate.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-__all__ = ["RedshiftSQLOperator", "RedshiftPauseClusterOperator", "RedshiftResumeClusterOperator"]
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 340e9577efef4..2da0fbf23a73d 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import time
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
@@ -88,6 +91,7 @@ class RedshiftCreateClusterOperator(BaseOperator):
"cluster_type",
"node_type",
"number_of_nodes",
+ "vpc_security_group_ids",
)
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"
@@ -102,32 +106,32 @@ def __init__(
cluster_type: str = "multi-node",
db_name: str = "dev",
number_of_nodes: int = 1,
- cluster_security_groups: Optional[List[str]] = None,
- vpc_security_group_ids: Optional[List[str]] = None,
- cluster_subnet_group_name: Optional[str] = None,
- availability_zone: Optional[str] = None,
- preferred_maintenance_window: Optional[str] = None,
- cluster_parameter_group_name: Optional[str] = None,
+ cluster_security_groups: list[str] | None = None,
+ vpc_security_group_ids: list[str] | None = None,
+ cluster_subnet_group_name: str | None = None,
+ availability_zone: str | None = None,
+ preferred_maintenance_window: str | None = None,
+ cluster_parameter_group_name: str | None = None,
automated_snapshot_retention_period: int = 1,
- manual_snapshot_retention_period: Optional[int] = None,
+ manual_snapshot_retention_period: int | None = None,
port: int = 5439,
cluster_version: str = "1.0",
allow_version_upgrade: bool = True,
publicly_accessible: bool = True,
encrypted: bool = False,
- hsm_client_certificate_identifier: Optional[str] = None,
- hsm_configuration_identifier: Optional[str] = None,
- elastic_ip: Optional[str] = None,
- tags: Optional[List[Any]] = None,
- kms_key_id: Optional[str] = None,
+ hsm_client_certificate_identifier: str | None = None,
+ hsm_configuration_identifier: str | None = None,
+ elastic_ip: str | None = None,
+ tags: list[Any] | None = None,
+ kms_key_id: str | None = None,
enhanced_vpc_routing: bool = False,
- additional_info: Optional[str] = None,
- iam_roles: Optional[List[str]] = None,
- maintenance_track_name: Optional[str] = None,
- snapshot_schedule_identifier: Optional[str] = None,
- availability_zone_relocation: Optional[bool] = None,
- aqua_configuration_status: Optional[str] = None,
- default_iam_role_arn: Optional[str] = None,
+ additional_info: str | None = None,
+ iam_roles: list[str] | None = None,
+ maintenance_track_name: str | None = None,
+ snapshot_schedule_identifier: str | None = None,
+ availability_zone_relocation: bool | None = None,
+ aqua_configuration_status: str | None = None,
+ default_iam_role_arn: str | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -168,10 +172,10 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.kwargs = kwargs
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
self.log.info("Creating Redshift cluster %s", self.cluster_identifier)
- params: Dict[str, Any] = {}
+ params: dict[str, Any] = {}
if self.db_name:
params["DBName"] = self.db_name
if self.cluster_type:
@@ -242,6 +246,130 @@ def execute(self, context: 'Context'):
self.log.info(cluster)
+class RedshiftCreateClusterSnapshotOperator(BaseOperator):
+ """
+ Creates a manual snapshot of the specified cluster. The cluster must be in the available state
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RedshiftCreateClusterSnapshotOperator`
+
+ :param snapshot_identifier: A unique identifier for the snapshot that you are requesting
+ :param cluster_identifier: The cluster identifier for which you want a snapshot
+ :param retention_period: The number of days that a manual snapshot is retained.
+ If the value is -1, the manual snapshot is retained indefinitely.
+ :param wait_for_completion: Whether to wait for the cluster snapshot to reach the ``available`` state
+ :param poll_interval: Time (in seconds) to wait between two consecutive calls to check state
+ :param max_attempt: The maximum number of attempts to be made to check the state
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ The default connection id is ``aws_default``
+ """
+
+ template_fields: Sequence[str] = (
+ "cluster_identifier",
+ "snapshot_identifier",
+ )
+
+ def __init__(
+ self,
+ *,
+ snapshot_identifier: str,
+ cluster_identifier: str,
+ retention_period: int = -1,
+ wait_for_completion: bool = False,
+ poll_interval: int = 15,
+ max_attempt: int = 20,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.snapshot_identifier = snapshot_identifier
+ self.cluster_identifier = cluster_identifier
+ self.retention_period = retention_period
+ self.wait_for_completion = wait_for_completion
+ self.poll_interval = poll_interval
+ self.max_attempt = max_attempt
+ self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
+
+ def execute(self, context: Context) -> Any:
+ cluster_state = self.redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
+ if cluster_state != "available":
+ raise AirflowException(
+ "Redshift cluster must be in available state. "
+ f"Redshift cluster current state is {cluster_state}"
+ )
+
+ self.redshift_hook.create_cluster_snapshot(
+ cluster_identifier=self.cluster_identifier,
+ snapshot_identifier=self.snapshot_identifier,
+ retention_period=self.retention_period,
+ )
+
+ if self.wait_for_completion:
+ self.redshift_hook.get_conn().get_waiter("snapshot_available").wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig={
+ "Delay": self.poll_interval,
+ "MaxAttempts": self.max_attempt,
+ },
+ )
+
+
+class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
+ """
+ Deletes the specified manual snapshot
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:RedshiftDeleteClusterSnapshotOperator`
+
+ :param snapshot_identifier: The unique identifier of the manual snapshot to delete
+ :param cluster_identifier: The unique identifier of the cluster the snapshot was created from
+ :param wait_for_completion: Whether to wait for the snapshot deletion to complete.
+ The default value is ``True``
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
+ The default connection id is ``aws_default``
+ :param poll_interval: Time (in seconds) to wait between two consecutive calls to check snapshot state
+ """
+
+ template_fields: Sequence[str] = (
+ "cluster_identifier",
+ "snapshot_identifier",
+ )
+
+ def __init__(
+ self,
+ *,
+ snapshot_identifier: str,
+ cluster_identifier: str,
+ wait_for_completion: bool = True,
+ aws_conn_id: str = "aws_default",
+ poll_interval: int = 10,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.snapshot_identifier = snapshot_identifier
+ self.cluster_identifier = cluster_identifier
+ self.wait_for_completion = wait_for_completion
+ self.poll_interval = poll_interval
+ self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
+
+ def execute(self, context: Context) -> Any:
+ self.redshift_hook.get_conn().delete_cluster_snapshot(
+ SnapshotClusterIdentifier=self.cluster_identifier,
+ SnapshotIdentifier=self.snapshot_identifier,
+ )
+
+ if self.wait_for_completion:
+ while self.get_status() is not None:
+ time.sleep(self.poll_interval)
+
+ def get_status(self) -> str | None:
+ return self.redshift_hook.get_cluster_snapshot_status(
+ snapshot_identifier=self.snapshot_identifier,
+ )
+
+
class RedshiftResumeClusterOperator(BaseOperator):
"""
Resume a paused AWS Redshift Cluster
@@ -268,17 +396,27 @@ def __init__(
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
+ # These parameters are added to address an issue with the boto3 API where the API
+ # prematurely reports the cluster as available to receive requests. This causes the cluster
+ # to reject initial attempts to resume the cluster despite reporting the correct state.
+ self._attempts = 10
+ self._attempt_interval = 15
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
- cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
- if cluster_state == 'paused':
- self.log.info("Starting Redshift cluster %s", self.cluster_identifier)
- redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
- else:
- self.log.warning(
- "Unable to resume cluster since cluster is currently in status: %s", cluster_state
- )
+
+ while self._attempts >= 1:
+ try:
+ redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
+ return
+ except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+ self._attempts = self._attempts - 1
+
+ if self._attempts > 0:
+ self.log.error("Unable to resume cluster. %d attempts remaining.", self._attempts)
+ time.sleep(self._attempt_interval)
+ else:
+ raise error
class RedshiftPauseClusterOperator(BaseOperator):
@@ -307,17 +445,27 @@ def __init__(
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
+ # These parameters are added to address an issue with the boto3 API where the API
+ # prematurely reports the cluster as available to receive requests. This causes the cluster
+ # to reject initial attempts to pause the cluster despite reporting the correct state.
+ self._attempts = 10
+ self._attempt_interval = 15
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
- cluster_state = redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
- if cluster_state == 'available':
- self.log.info("Pausing Redshift cluster %s", self.cluster_identifier)
- redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
- else:
- self.log.warning(
- "Unable to pause cluster since cluster is currently in status: %s", cluster_state
- )
+
+ while self._attempts >= 1:
+ try:
+ redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+ return
+ except redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+ self._attempts = self._attempts - 1
+
+ if self._attempts > 0:
+ self.log.error("Unable to pause cluster. %d attempts remaining.", self._attempts)
+ time.sleep(self._attempt_interval)
+ else:
+ raise error
class RedshiftDeleteClusterOperator(BaseOperator):
@@ -346,7 +494,7 @@ def __init__(
*,
cluster_identifier: str,
skip_final_cluster_snapshot: bool = True,
- final_cluster_snapshot_identifier: Optional[str] = None,
+ final_cluster_snapshot_identifier: str | None = None,
wait_for_completion: bool = True,
aws_conn_id: str = "aws_default",
poll_interval: float = 30.0,
@@ -360,24 +508,16 @@ def __init__(
self.redshift_hook = RedshiftHook(aws_conn_id=aws_conn_id)
self.poll_interval = poll_interval
- def execute(self, context: 'Context'):
- self.delete_cluster()
-
- if self.wait_for_completion:
- cluster_status: str = self.check_status()
- while cluster_status != "cluster_not_found":
- self.log.info(
- "cluster status is %s. Sleeping for %s seconds.", cluster_status, self.poll_interval
- )
- time.sleep(self.poll_interval)
- cluster_status = self.check_status()
-
- def delete_cluster(self) -> None:
+ def execute(self, context: Context):
self.redshift_hook.delete_cluster(
cluster_identifier=self.cluster_identifier,
skip_final_cluster_snapshot=self.skip_final_cluster_snapshot,
final_cluster_snapshot_identifier=self.final_cluster_snapshot_identifier,
)
- def check_status(self) -> str:
- return self.redshift_hook.cluster_status(self.cluster_identifier)
+ if self.wait_for_completion:
+ waiter = self.redshift_hook.get_conn().get_waiter("cluster_deleted")
+ waiter.wait(
+ ClusterIdentifier=self.cluster_identifier,
+ WaiterConfig={"Delay": self.poll_interval, "MaxAttempts": 30},
+ )
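Reviewer note on the retry loops added to `RedshiftResumeClusterOperator` and `RedshiftPauseClusterOperator` above: both retry the request a fixed number of times when the cluster rejects it with `InvalidClusterStateFault`. Below is a minimal, standalone sketch of that bounded-retry pattern. It does not use boto3; `InvalidClusterStateFault` here is a local stand-in class, and `call_with_retries` and `FlakyCluster` are hypothetical names introduced only for this example.

```python
import time


class InvalidClusterStateFault(Exception):
    """Local stand-in for the client exception of the same name."""


def call_with_retries(action, attempts=10, interval=15, sleep=time.sleep):
    """Call ``action`` until it succeeds or ``attempts`` calls have failed."""
    remaining = attempts
    while remaining >= 1:
        try:
            return action()
        except InvalidClusterStateFault:
            remaining -= 1
            if remaining == 0:
                raise  # out of attempts: surface the last failure
            sleep(interval)


class FlakyCluster:
    """Hypothetical client that rejects the first ``failures`` requests."""

    def __init__(self, failures):
        self.failures = failures
        self.calls = 0

    def resume(self):
        self.calls += 1
        if self.calls <= self.failures:
            raise InvalidClusterStateFault("cluster not ready")
        return "resuming"
```

The sketch keeps the attempt counter in a local variable and takes `sleep` as a parameter so the loop can be exercised without real delays. Both are choices of this example; the diff stores the counter on the operator instance as `self._attempts`.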
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index f2d47da655835..416ce8146ef6d 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -15,17 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
-from time import sleep
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from __future__ import annotations
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from time import sleep
+from typing import TYPE_CHECKING, Any
+from airflow.compat.functools import cached_property
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.utils import trim_none_values
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -40,7 +38,7 @@ class RedshiftDataOperator(BaseOperator):
:ref:`howto/operator:RedshiftDataOperator`
:param database: the name of the database
- :param sql: the SQL statement text to run
+ :param sql: the SQL statement or list of SQL statements to run
:param cluster_identifier: unique identifier of a cluster
:param db_user: the database username
:param parameters: the parameters for the SQL statement
@@ -54,32 +52,32 @@ class RedshiftDataOperator(BaseOperator):
"""
template_fields = (
- 'cluster_identifier',
- 'database',
- 'sql',
- 'db_user',
- 'parameters',
- 'statement_name',
- 'aws_conn_id',
- 'region',
+ "cluster_identifier",
+ "database",
+ "sql",
+ "db_user",
+ "parameters",
+ "statement_name",
+ "aws_conn_id",
+ "region",
)
- template_ext = ('.sql',)
- template_fields_renderers = {'sql': 'sql'}
+ template_ext = (".sql",)
+ template_fields_renderers = {"sql": "sql"}
def __init__(
self,
database: str,
- sql: str,
- cluster_identifier: Optional[str] = None,
- db_user: Optional[str] = None,
- parameters: Optional[list] = None,
- secret_arn: Optional[str] = None,
- statement_name: Optional[str] = None,
+ sql: str | list,
+ cluster_identifier: str | None = None,
+ db_user: str | None = None,
+ parameters: list | None = None,
+ secret_arn: str | None = None,
+ statement_name: str | None = None,
with_event: bool = False,
await_result: bool = True,
poll_interval: int = 10,
- aws_conn_id: str = 'aws_default',
- region: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ region: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -109,7 +107,7 @@ def hook(self) -> RedshiftDataHook:
return RedshiftDataHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
def execute_query(self):
- kwargs: Dict[str, Any] = {
+ kwargs: dict[str, Any] = {
"ClusterIdentifier": self.cluster_identifier,
"Database": self.database,
"Sql": self.sql,
@@ -120,9 +118,22 @@ def execute_query(self):
"StatementName": self.statement_name,
}
- filter_values = {key: val for key, val in kwargs.items() if val is not None}
- resp = self.hook.conn.execute_statement(**filter_values)
- return resp['Id']
+ resp = self.hook.conn.execute_statement(**trim_none_values(kwargs))
+ return resp["Id"]
+
+ def execute_batch_query(self):
+ kwargs: dict[str, Any] = {
+ "ClusterIdentifier": self.cluster_identifier,
+ "Database": self.database,
+ "Sqls": self.sql,
+ "DbUser": self.db_user,
+ "Parameters": self.parameters,
+ "WithEvent": self.with_event,
+ "SecretArn": self.secret_arn,
+ "StatementName": self.statement_name,
+ }
+ resp = self.hook.conn.batch_execute_statement(**trim_none_values(kwargs))
+ return resp["Id"]
def wait_for_results(self, statement_id):
while True:
@@ -130,20 +141,22 @@ def wait_for_results(self, statement_id):
resp = self.hook.conn.describe_statement(
Id=statement_id,
)
- status = resp['Status']
- if status == 'FINISHED':
+ status = resp["Status"]
+ if status == "FINISHED":
return status
- elif status == 'FAILED' or status == 'ABORTED':
+ elif status == "FAILED" or status == "ABORTED":
raise ValueError(f"Statement {statement_id!r} terminated with status {status}.")
else:
self.log.info("Query %s", status)
sleep(self.poll_interval)
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
"""Execute a statement against Amazon Redshift"""
self.log.info("Executing statement: %s", self.sql)
-
- self.statement_id = self.execute_query()
+ if isinstance(self.sql, list):
+ self.statement_id = self.execute_batch_query()
+ else:
+ self.statement_id = self.execute_query()
if self.await_result:
self.wait_for_results(self.statement_id)
@@ -153,10 +166,10 @@ def execute(self, context: 'Context') -> None:
def on_kill(self) -> None:
"""Cancel the submitted redshift query"""
if self.statement_id:
- self.log.info('Received a kill signal.')
- self.log.info('Stopping Query with statementId - %s', self.statement_id)
+ self.log.info("Received a kill signal.")
+ self.log.info("Stopping Query with statementId - %s", self.statement_id)
try:
self.hook.conn.cancel_statement(Id=self.statement_id)
except Exception as ex:
- self.log.error('Unable to cancel query. Exiting. %s', ex)
+ self.log.error("Unable to cancel query. Exiting. %s", ex)
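Reviewer note on the `RedshiftDataOperator` change above: `execute` now chooses between a single-statement and a batch request depending on whether `sql` is a list, and both paths drop `None` values before calling the client. The following is a minimal, standalone sketch of that dispatch. It does not call AWS; `trim_none` is a local stand-in that mirrors the dictionary comprehension removed in the diff, and `build_request` is a hypothetical helper that only returns the method name and the arguments that would be sent.

```python
from __future__ import annotations

from typing import Any


def trim_none(values: dict[str, Any]) -> dict[str, Any]:
    """Local stand-in: drop every key whose value is None."""
    return {key: val for key, val in values.items() if val is not None}


def build_request(
    sql: str | list,
    database: str,
    cluster_identifier: str | None = None,
    db_user: str | None = None,
) -> tuple[str, dict[str, Any]]:
    """Pick the client method and arguments based on the type of ``sql``."""
    common = {
        "ClusterIdentifier": cluster_identifier,
        "Database": database,
        "DbUser": db_user,
    }
    if isinstance(sql, list):
        # A list of statements goes to the batch call under the "Sqls" key.
        return "batch_execute_statement", trim_none({**common, "Sqls": sql})
    # A single string goes to the single-statement call under the "Sql" key.
    return "execute_statement", trim_none({**common, "Sql": sql})
```

For example, `build_request(["SELECT 1", "SELECT 2"], database="dev")` selects the batch method, while a plain string selects the single-statement method; unset optional arguments are omitted in both cases.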
diff --git a/airflow/providers/amazon/aws/operators/redshift_sql.py b/airflow/providers/amazon/aws/operators/redshift_sql.py
index c7ad77acb5341..b425ac262d17f 100644
--- a/airflow/providers/amazon/aws/operators/redshift_sql.py
+++ b/airflow/providers/amazon/aws/operators/redshift_sql.py
@@ -14,18 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
+import warnings
+from typing import Sequence
-from airflow.models import BaseOperator
-from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
-from airflow.www import utils as wwwutils
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
-if TYPE_CHECKING:
- from airflow.utils.context import Context
-
-class RedshiftSQLOperator(BaseOperator):
+class RedshiftSQLOperator(SQLExecuteQueryOperator):
"""
Executes SQL Statements against an Amazon Redshift cluster
@@ -43,36 +40,18 @@ class RedshiftSQLOperator(BaseOperator):
(default value: False)
"""
- template_fields: Sequence[str] = ('sql',)
- template_ext: Sequence[str] = ('.sql',)
- # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
- template_fields_renderers = {
- "sql": "postgresql" if "postgresql" in wwwutils.get_attr_renderer() else "sql"
- }
-
- def __init__(
- self,
- *,
- sql: Union[str, Iterable[str]],
- redshift_conn_id: str = 'redshift_default',
- parameters: Optional[dict] = None,
- autocommit: bool = True,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.redshift_conn_id = redshift_conn_id
- self.sql = sql
- self.autocommit = autocommit
- self.parameters = parameters
-
- def get_hook(self) -> RedshiftSQLHook:
- """Create and return RedshiftSQLHook.
- :return RedshiftSQLHook: A RedshiftSQLHook instance.
- """
- return RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
-
- def execute(self, context: 'Context') -> None:
- """Execute a statement against Amazon Redshift"""
- self.log.info("Executing statement: %s", self.sql)
- hook = self.get_hook()
- hook.run(self.sql, autocommit=self.autocommit, parameters=self.parameters)
+ template_fields: Sequence[str] = (
+ "sql",
+ "conn_id",
+ )
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "postgresql"}
+
+ def __init__(self, *, redshift_conn_id: str = "redshift_default", **kwargs) -> None:
+ super().__init__(conn_id=redshift_conn_id, **kwargs)
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""",
+ DeprecationWarning,
+ stacklevel=2,
+ )
diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py
index 7fbef6629c7f7..91732084017d1 100644
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -15,13 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-
"""This module contains AWS S3 operators."""
+from __future__ import annotations
+
import subprocess
import sys
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -57,17 +57,16 @@ def __init__(
self,
*,
bucket_name: str,
- aws_conn_id: Optional[str] = "aws_default",
- region_name: Optional[str] = None,
+ aws_conn_id: str | None = "aws_default",
+ region_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.region_name = region_name
self.aws_conn_id = aws_conn_id
- self.region_name = region_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
if not s3_hook.check_for_bucket(self.bucket_name):
s3_hook.create_bucket(bucket_name=self.bucket_name, region_name=self.region_name)
@@ -99,7 +98,7 @@ def __init__(
self,
bucket_name: str,
force_delete: bool = False,
- aws_conn_id: Optional[str] = "aws_default",
+ aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -107,7 +106,7 @@ def __init__(
self.force_delete = force_delete
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
if s3_hook.check_for_bucket(self.bucket_name):
s3_hook.delete_bucket(bucket_name=self.bucket_name, force_delete=self.force_delete)
@@ -134,12 +133,12 @@ class S3GetBucketTaggingOperator(BaseOperator):
template_fields: Sequence[str] = ("bucket_name",)
- def __init__(self, bucket_name: str, aws_conn_id: Optional[str] = "aws_default", **kwargs) -> None:
+ def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
if s3_hook.check_for_bucket(self.bucket_name):
@@ -177,10 +176,10 @@ class S3PutBucketTaggingOperator(BaseOperator):
def __init__(
self,
bucket_name: str,
- key: Optional[str] = None,
- value: Optional[str] = None,
- tag_set: Optional[List[Dict[str, str]]] = None,
- aws_conn_id: Optional[str] = "aws_default",
+ key: str | None = None,
+ value: str | None = None,
+ tag_set: list[dict[str, str]] | None = None,
+ aws_conn_id: str | None = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -190,7 +189,7 @@ def __init__(
self.bucket_name = bucket_name
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
if s3_hook.check_for_bucket(self.bucket_name):
@@ -221,12 +220,12 @@ class S3DeleteBucketTaggingOperator(BaseOperator):
template_fields: Sequence[str] = ("bucket_name",)
- def __init__(self, bucket_name: str, aws_conn_id: Optional[str] = "aws_default", **kwargs) -> None:
+ def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default", **kwargs) -> None:
super().__init__(**kwargs)
self.bucket_name = bucket_name
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
if s3_hook.check_for_bucket(self.bucket_name):
@@ -280,10 +279,10 @@ class S3CopyObjectOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'source_bucket_key',
- 'dest_bucket_key',
- 'source_bucket_name',
- 'dest_bucket_name',
+ "source_bucket_key",
+ "dest_bucket_key",
+ "source_bucket_name",
+ "dest_bucket_name",
)
def __init__(
@@ -291,12 +290,12 @@ def __init__(
*,
source_bucket_key: str,
dest_bucket_key: str,
- source_bucket_name: Optional[str] = None,
- dest_bucket_name: Optional[str] = None,
- source_version_id: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
- acl_policy: Optional[str] = None,
+ source_bucket_name: str | None = None,
+ dest_bucket_name: str | None = None,
+ source_version_id: str | None = None,
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
+ acl_policy: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -310,7 +309,7 @@ def __init__(
self.verify = verify
self.acl_policy = acl_policy
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
s3_hook.copy_object(
self.source_bucket_key,
@@ -360,21 +359,21 @@ class S3CreateObjectOperator(BaseOperator):
"""
- template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'data')
+ template_fields: Sequence[str] = ("s3_bucket", "s3_key", "data")
def __init__(
self,
*,
- s3_bucket: Optional[str] = None,
+ s3_bucket: str | None = None,
s3_key: str,
- data: Union[str, bytes],
+ data: str | bytes,
replace: bool = False,
encrypt: bool = False,
- acl_policy: Optional[str] = None,
- encoding: Optional[str] = None,
- compression: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
+ acl_policy: str | None = None,
+ encoding: str | None = None,
+ compression: str | None = None,
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -390,10 +389,10 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
- s3_bucket, s3_key = s3_hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, 'dest_bucket', 'dest_key')
+ s3_bucket, s3_key = s3_hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
if isinstance(self.data, str):
s3_hook.load_string(
@@ -444,16 +443,16 @@ class S3DeleteObjectsOperator(BaseOperator):
CA cert bundle than the one used by botocore.
"""
- template_fields: Sequence[str] = ('keys', 'bucket', 'prefix')
+ template_fields: Sequence[str] = ("keys", "bucket", "prefix")
def __init__(
self,
*,
bucket: str,
- keys: Optional[Union[str, list]] = None,
- prefix: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
+ keys: str | list | None = None,
+ prefix: str | None = None,
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
**kwargs,
):
@@ -467,7 +466,7 @@ def __init__(
if not bool(keys is None) ^ bool(prefix is None):
raise AirflowException("Either keys or prefix should be set.")
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
if not bool(self.keys is None) ^ bool(self.prefix is None):
raise AirflowException("Either keys or prefix should be set.")
@@ -525,22 +524,22 @@ class S3FileTransformOperator(BaseOperator):
:param replace: Replace dest S3 key if it already exists
"""
- template_fields: Sequence[str] = ('source_s3_key', 'dest_s3_key', 'script_args')
+ template_fields: Sequence[str] = ("source_s3_key", "dest_s3_key", "script_args")
template_ext: Sequence[str] = ()
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
def __init__(
self,
*,
source_s3_key: str,
dest_s3_key: str,
- transform_script: Optional[str] = None,
+ transform_script: str | None = None,
select_expression=None,
- script_args: Optional[Sequence[str]] = None,
- source_aws_conn_id: str = 'aws_default',
- source_verify: Optional[Union[bool, str]] = None,
- dest_aws_conn_id: str = 'aws_default',
- dest_verify: Optional[Union[bool, str]] = None,
+ script_args: Sequence[str] | None = None,
+ source_aws_conn_id: str = "aws_default",
+ source_verify: bool | str | None = None,
+ dest_aws_conn_id: str = "aws_default",
+ dest_verify: bool | str | None = None,
replace: bool = False,
**kwargs,
) -> None:
@@ -558,7 +557,7 @@ def __init__(
self.script_args = script_args or []
self.output_encoding = sys.getdefaultencoding()
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
if self.transform_script is None and self.select_expression is None:
raise AirflowException("Either transform_script or select_expression must be specified")
@@ -589,7 +588,7 @@ def execute(self, context: 'Context'):
) as process:
self.log.info("Output:")
if process.stdout is not None:
- for line in iter(process.stdout.readline, b''):
+ for line in iter(process.stdout.readline, b""):
self.log.info(line.decode(self.output_encoding).rstrip())
process.wait()
@@ -653,17 +652,17 @@ class S3ListOperator(BaseOperator):
)
"""
- template_fields: Sequence[str] = ('bucket', 'prefix', 'delimiter')
- ui_color = '#ffd700'
+ template_fields: Sequence[str] = ("bucket", "prefix", "delimiter")
+ ui_color = "#ffd700"
def __init__(
self,
*,
bucket: str,
- prefix: str = '',
- delimiter: str = '',
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
+ prefix: str = "",
+ delimiter: str = "",
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -673,11 +672,11 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
self.log.info(
- 'Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)',
+ "Getting the list of files from bucket: %s in prefix: %s (Delimiter %s)",
self.bucket,
self.prefix,
self.delimiter,
@@ -727,8 +726,8 @@ class S3ListPrefixesOperator(BaseOperator):
)
"""
- template_fields: Sequence[str] = ('bucket', 'prefix', 'delimiter')
- ui_color = '#ffd700'
+ template_fields: Sequence[str] = ("bucket", "prefix", "delimiter")
+ ui_color = "#ffd700"
def __init__(
self,
@@ -736,8 +735,8 @@ def __init__(
bucket: str,
prefix: str,
delimiter: str,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -747,11 +746,11 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.verify = verify
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
self.log.info(
- 'Getting the list of subfolders from bucket: %s in prefix: %s (Delimiter %s)',
+ "Getting the list of subfolders from bucket: %s in prefix: %s (Delimiter %s)",
self.bucket,
self.prefix,
self.delimiter,
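Note on the annotation changes in the s3.py diff above: the hunks replace `Optional[str]` with `str | None` and drop the quotes around `Context`. Both depend on the `from __future__ import annotations` line added at the top of the module, which stores annotations as strings instead of evaluating them. The following is a minimal sketch of that behaviour, not code from the patch; `ExampleOperator` and the `Mapping` stand-in for `Context` are hypothetical names used only for illustration.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
    # Seen only by type checkers; this import never runs at runtime.
    # Hypothetical stand-in for airflow.utils.context.Context.
    from typing import Mapping as Context


class ExampleOperator:
    """Hypothetical stand-in for a BaseOperator subclass."""

    template_fields: Sequence[str] = ("bucket_name",)

    def __init__(self, bucket_name: str, aws_conn_id: str | None = "aws_default") -> None:
        self.bucket_name = bucket_name
        self.aws_conn_id = aws_conn_id

    # ``Context`` is undefined at runtime, yet the unquoted annotation is fine
    # because annotations are kept as strings and never evaluated here.
    def execute(self, context: Context) -> str:
        return f"{self.bucket_name} via {self.aws_conn_id}"


op = ExampleOperator(bucket_name="demo-bucket")
result = op.execute(context={})
```

Without the `__future__` import, the unquoted `Context` annotation would raise `NameError` when the class body is executed, which is why the original code quoted it.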
diff --git a/airflow/providers/amazon/aws/operators/s3_bucket.py b/airflow/providers/amazon/aws/operators/s3_bucket.py
deleted file mode 100644
index e5806fa780848..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_bucket.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/s3_bucket_tagging.py b/airflow/providers/amazon/aws/operators/s3_bucket_tagging.py
deleted file mode 100644
index fb214412643ae..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_bucket_tagging.py
+++ /dev/null
@@ -1,32 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import ( # noqa
- S3DeleteBucketTaggingOperator,
- S3GetBucketTaggingOperator,
- S3PutBucketTaggingOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/s3_copy_object.py b/airflow/providers/amazon/aws/operators/s3_copy_object.py
deleted file mode 100644
index 298f8d21724cb..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_copy_object.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/s3_delete_objects.py b/airflow/providers/amazon/aws/operators/s3_delete_objects.py
deleted file mode 100644
index 35d86893eb377..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_delete_objects.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import S3DeleteObjectsOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/s3_file_transform.py b/airflow/providers/amazon/aws/operators/s3_file_transform.py
deleted file mode 100644
index 4400b202bbda0..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_file_transform.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import S3FileTransformOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/s3_list.py b/airflow/providers/amazon/aws/operators/s3_list.py
deleted file mode 100644
index c114a0b814886..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_list.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import S3ListOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/s3_list_prefixes.py b/airflow/providers/amazon/aws/operators/s3_list_prefixes.py
deleted file mode 100644
index 5e94f0c17f471..0000000000000
--- a/airflow/providers/amazon/aws/operators/s3_list_prefixes.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.s3 import S3ListPrefixesOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
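Note on the deleted `s3_*.py` modules above: each one was a thin shim that re-exported operators from `airflow.providers.amazon.aws.operators.s3` and emitted a `DeprecationWarning` on import. With the shims removed, imports have to target the `s3` module directly, for example `from airflow.providers.amazon.aws.operators.s3 import S3ListOperator`. The sketch below reproduces only the warning pattern with the standard library, so it runs without Airflow installed; `load_deprecated_shim` is a hypothetical function standing in for the import of a shim module.

```python
import warnings


def load_deprecated_shim() -> None:
    """Hypothetical stand-in for importing one of the removed shim modules."""
    warnings.warn(
        "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.s3`.",
        DeprecationWarning,
        # stacklevel=2 attributes the warning to the caller, not to the shim itself.
        stacklevel=2,
    )


# Record warnings instead of printing them so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    load_deprecated_shim()

messages = [str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)]
```

Running a test suite with `-W error::DeprecationWarning` is one way to find remaining imports of such shims before they are removed.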
diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 084e3e857e9be..c5f08db049206 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -14,28 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
-import sys
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from botocore.exceptions import ClientError
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from airflow.utils.json import AirflowJsonEncoder
if TYPE_CHECKING:
from airflow.utils.context import Context
-DEFAULT_CONN_ID = 'aws_default'
-CHECK_INTERVAL_SECOND = 30
+DEFAULT_CONN_ID: str = "aws_default"
+CHECK_INTERVAL_SECOND: int = 30
+
+
+def serialize(result: dict) -> dict:
+ return json.loads(json.dumps(result, cls=AirflowJsonEncoder))
class SageMakerBaseOperator(BaseOperator):
@@ -44,17 +45,18 @@ class SageMakerBaseOperator(BaseOperator):
:param config: The configuration necessary to start a training job (templated)
"""
- template_fields: Sequence[str] = ('config',)
+ template_fields: Sequence[str] = ("config",)
template_ext: Sequence[str] = ()
- template_fields_renderers = {'config': 'json'}
- ui_color = '#ededed'
- integer_fields: List[List[Any]] = []
+ template_fields_renderers: dict = {"config": "json"}
+ ui_color: str = "#ededed"
+ integer_fields: list[list[Any]] = []
- def __init__(self, *, config: dict, **kwargs):
+ def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
super().__init__(**kwargs)
self.config = config
+ self.aws_conn_id = aws_conn_id
- def parse_integer(self, config, field):
+ def parse_integer(self, config: dict, field: list[str] | str) -> None:
"""Recursive method for parsing string fields holding integer values to integers."""
if len(field) == 1:
if isinstance(config, list):
@@ -74,34 +76,39 @@ def parse_integer(self, config, field):
self.parse_integer(config[head], tail)
return
- def parse_config_integers(self):
- """
- Parse the integer fields of training config to integers in case the config is rendered by Jinja and
- all fields are str
- """
+ def parse_config_integers(self) -> None:
+ """Parse the integer fields to ints in case the config is rendered by Jinja and all fields are str."""
for field in self.integer_fields:
self.parse_integer(self.config, field)
- def expand_role(self):
+ def expand_role(self) -> None:
"""Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""
- def preprocess_config(self):
+ def preprocess_config(self) -> None:
"""Process the config into a usable form."""
- self.log.info('Preprocessing the config and doing required s3_operations')
+ self._create_integer_fields()
+ self.log.info("Preprocessing the config and doing required s3_operations")
self.hook.configure_s3_resources(self.config)
self.parse_config_integers()
self.expand_role()
self.log.info(
- 'After preprocessing the config is:\n %s',
- json.dumps(self.config, sort_keys=True, indent=4, separators=(',', ': ')),
+ "After preprocessing the config is:\n %s",
+ json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")),
)
- def execute(self, context: 'Context'):
- raise NotImplementedError('Please implement execute() in sub class!')
+ def _create_integer_fields(self) -> None:
+ """
+ Set fields which should be cast to integers.
+ Child classes should override this method if they need integer fields parsed.
+ """
+ self.integer_fields = []
+
+ def execute(self, context: Context) -> None | dict:
+ raise NotImplementedError("Please implement execute() in sub class!")
@cached_property
def hook(self):
- """Return SageMakerHook"""
+ """Return SageMakerHook."""
return SageMakerHook(aws_conn_id=self.aws_conn_id)
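Note on `parse_integer` and `_create_integer_fields` in the hunk above: after Jinja rendering every config value may be a string, so each operator lists the key paths that must be cast back to `int`, and `preprocess_config` walks them. The body of `parse_integer` is only partly visible in this diff, so the function below is an independent sketch of the same idea (recursive descent along a key path, fanning out over lists). It is not the operator's actual implementation.

```python
from typing import Any


def parse_integer(config: Any, field: list[str]) -> None:
    """Cast the value at key path ``field`` to int, in place."""
    if isinstance(config, list):
        # A list of sub-configs: apply the same path to every element.
        for item in config:
            parse_integer(item, field)
        return
    head, tail = field[0], field[1:]
    if head not in config:
        return  # optional keys are simply skipped
    if tail:
        parse_integer(config[head], tail)
    else:
        config[head] = int(config[head])


# Example config as it might look after template rendering (all strings).
config = {
    "ProcessingResources": {"ClusterConfig": {"InstanceCount": "2", "VolumeSizeInGB": "30"}},
    "ProductionVariants": [{"InitialInstanceCount": "1"}, {"InitialInstanceCount": "3"}],
}
integer_fields = [
    ["ProcessingResources", "ClusterConfig", "InstanceCount"],
    ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"],
    ["ProductionVariants", "InitialInstanceCount"],
    ["StoppingCondition", "MaxRuntimeInSeconds"],  # absent here, so ignored
]
for field in integer_fields:
    parse_integer(config, field)
```

Moving the field list into an overridable `_create_integer_fields` hook, called from `preprocess_config`, means the list is built from the rendered config rather than from the constructor arguments.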
@@ -140,55 +147,64 @@ def __init__(
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
- max_ingestion_time: Optional[int] = None,
- action_if_job_exists: str = 'increment',
+ max_ingestion_time: int | None = None,
+ action_if_job_exists: str = "increment",
**kwargs,
):
- super().__init__(config=config, **kwargs)
- if action_if_job_exists not in ('increment', 'fail'):
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+ if action_if_job_exists not in ("increment", "fail"):
raise AirflowException(
f"Argument action_if_job_exists accepts only 'increment' and 'fail'. \
Provided value: '{action_if_job_exists}'."
)
self.action_if_job_exists = action_if_job_exists
- self.aws_conn_id = aws_conn_id
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
- self._create_integer_fields()
def _create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- self.integer_fields = [
- ['ProcessingResources', 'ClusterConfig', 'InstanceCount'],
- ['ProcessingResources', 'ClusterConfig', 'VolumeSizeInGB'],
+ """Set fields which should be cast to integers."""
+ self.integer_fields: list[list[str] | list[list[str]]] = [
+ ["ProcessingResources", "ClusterConfig", "InstanceCount"],
+ ["ProcessingResources", "ClusterConfig", "VolumeSizeInGB"],
]
- if 'StoppingCondition' in self.config:
- self.integer_fields += [['StoppingCondition', 'MaxRuntimeInSeconds']]
+ if "StoppingCondition" in self.config:
+ self.integer_fields.append(["StoppingCondition", "MaxRuntimeInSeconds"])
def expand_role(self) -> None:
- if 'RoleArn' in self.config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
+ """Expands an IAM role name into an ARN."""
+ if "RoleArn" in self.config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
+ self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
- processing_job_name = self.config['ProcessingJobName']
- if self.hook.find_processing_job_by_name(processing_job_name):
- raise AirflowException(
- f'A SageMaker processing job with name {processing_job_name} already exists.'
- )
- self.log.info('Creating SageMaker processing job %s.', self.config['ProcessingJobName'])
+ processing_job_name = self.config["ProcessingJobName"]
+ processing_job_dedupe_pattern = "-[0-9]+$"
+ existing_jobs_found = self.hook.count_processing_jobs_by_name(
+ processing_job_name, processing_job_dedupe_pattern
+ )
+ if existing_jobs_found:
+ if self.action_if_job_exists == "fail":
+ raise AirflowException(
+ f"A SageMaker processing job with name {processing_job_name} already exists."
+ )
+ elif self.action_if_job_exists == "increment":
+ self.log.info("Found existing processing job with name '%s'.", processing_job_name)
+ new_processing_job_name = f"{processing_job_name}-{existing_jobs_found + 1}"
+ self.config["ProcessingJobName"] = new_processing_job_name
+ self.log.info("Incremented processing job name to '%s'.", new_processing_job_name)
+
response = self.hook.create_processing_job(
self.config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker Processing Job creation failed: {response}')
- return {'Processing': self.hook.describe_processing_job(self.config['ProcessingJobName'])}
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker Processing Job creation failed: {response}")
+ return {"Processing": serialize(self.hook.describe_processing_job(self.config["ProcessingJobName"]))}
class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
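Note on the `action_if_job_exists` handling in the hunk above: the operator now counts existing jobs whose name is the requested name, optionally followed by a numeric suffix matching `-[0-9]+$`, and either fails or appends the next number. The hook method `count_processing_jobs_by_name` is not shown in this diff, so the sketch below replaces it with a hypothetical pure function over a list of names. The matching rule is an assumption made for illustration.

```python
import re


def count_jobs_by_name(existing: list[str], base_name: str, dedupe_pattern: str = "-[0-9]+$") -> int:
    """Count names equal to ``base_name`` or ``base_name`` plus a numeric suffix (assumed rule)."""
    pattern = re.compile(f"^{re.escape(base_name)}({dedupe_pattern})?$")
    return sum(1 for name in existing if pattern.match(name))


def next_job_name(existing: list[str], base_name: str, action_if_job_exists: str = "increment") -> str:
    """Return the job name to use, mirroring the increment/fail branches."""
    found = count_jobs_by_name(existing, base_name)
    if not found:
        return base_name
    if action_if_job_exists == "fail":
        raise ValueError(f"A job with name {base_name} already exists.")
    return f"{base_name}-{found + 1}"


first = next_job_name([], "job")
second = next_job_name(["job"], "job")
third = next_job_name(["job", "job-2", "job-extra"], "job")
```

Because the new suffix is derived from a count, the scheme assumes earlier suffixed jobs are still present and were numbered consecutively.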
@@ -209,8 +225,6 @@ class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the endpoint config created in Amazon SageMaker.
"""
- integer_fields = [['ProductionVariants', 'InitialInstanceCount']]
-
def __init__(
self,
*,
@@ -218,18 +232,24 @@ def __init__(
aws_conn_id: str = DEFAULT_CONN_ID,
**kwargs,
):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.aws_conn_id = aws_conn_id
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
+
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: list[list[str]] = [["ProductionVariants", "InitialInstanceCount"]]
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
- self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
+ self.log.info("Creating SageMaker Endpoint Config %s.", self.config["EndpointConfigName"])
response = self.hook.create_endpoint_config(self.config)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker endpoint config creation failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker endpoint config creation failed: {response}")
else:
- return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
+ return {
+ "EndpointConfig": serialize(
+ self.hook.describe_endpoint_config(self.config["EndpointConfigName"])
+ )
+ }
class SageMakerEndpointOperator(SageMakerBaseOperator):
@@ -287,54 +307,54 @@ def __init__(
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
- max_ingestion_time: Optional[int] = None,
- operation: str = 'create',
+ max_ingestion_time: int | None = None,
+ operation: str = "create",
**kwargs,
):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.aws_conn_id = aws_conn_id
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.operation = operation.lower()
- if self.operation not in ['create', 'update']:
+ if self.operation not in ["create", "update"]:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
- self.create_integer_fields()
- def create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- if 'EndpointConfig' in self.config:
- self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ if "EndpointConfig" in self.config:
+ self.integer_fields: list[list[str]] = [
+ ["EndpointConfig", "ProductionVariants", "InitialInstanceCount"]
+ ]
def expand_role(self) -> None:
- if 'Model' not in self.config:
+ """Expands an IAM role name into an ARN."""
+ if "Model" not in self.config:
return
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- config = self.config['Model']
- if 'ExecutionRoleArn' in config:
- config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
+ hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
+ config = self.config["Model"]
+ if "ExecutionRoleArn" in config:
+ config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
- model_info = self.config.get('Model')
- endpoint_config_info = self.config.get('EndpointConfig')
- endpoint_info = self.config.get('Endpoint', self.config)
+ model_info = self.config.get("Model")
+ endpoint_config_info = self.config.get("EndpointConfig")
+ endpoint_info = self.config.get("Endpoint", self.config)
if model_info:
- self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
+ self.log.info("Creating SageMaker model %s.", model_info["ModelName"])
self.hook.create_model(model_info)
if endpoint_config_info:
- self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
+ self.log.info("Creating endpoint config %s.", endpoint_config_info["EndpointConfigName"])
self.hook.create_endpoint_config(endpoint_config_info)
- if self.operation == 'create':
+ if self.operation == "create":
sagemaker_operation = self.hook.create_endpoint
- log_str = 'Creating'
- elif self.operation == 'update':
+ log_str = "Creating"
+ elif self.operation == "update":
sagemaker_operation = self.hook.update_endpoint
- log_str = 'Updating'
+ log_str = "Updating"
else:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
- self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
+ self.log.info("%s SageMaker endpoint %s.", log_str, endpoint_info["EndpointName"])
try:
response = sagemaker_operation(
endpoint_info,
@@ -343,21 +363,23 @@ def execute(self, context: 'Context') -> dict:
max_ingestion_time=self.max_ingestion_time,
)
except ClientError:
- self.operation = 'update'
+ self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
- log_str = 'Updating'
+ log_str = "Updating"
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker endpoint creation failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker endpoint creation failed: {response}")
else:
return {
- 'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
- 'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
+ "EndpointConfig": serialize(
+ self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
+ ),
+ "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
}
@@ -396,6 +418,11 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
:param max_ingestion_time: If wait is set to True, the operation fails
if the transform job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
+ :param check_if_job_exists: If set to true, then the operator will check whether a transform job
+ already exists for the name in the config.
+ :param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
+ (default) and "fail".
+ This is only relevant if check_if_job_exists is True.
:return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""
@@ -406,58 +433,85 @@ def __init__(
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
+ check_if_job_exists: bool = True,
+ action_if_job_exists: str = "increment",
**kwargs,
):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.aws_conn_id = aws_conn_id
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
- self.create_integer_fields()
-
- def create_integer_fields(self) -> None:
- """Set fields which should be casted to integers."""
- self.integer_fields: List[List[str]] = [
- ['Transform', 'TransformResources', 'InstanceCount'],
- ['Transform', 'MaxConcurrentTransforms'],
- ['Transform', 'MaxPayloadInMB'],
+ self.check_if_job_exists = check_if_job_exists
+ if action_if_job_exists in ("increment", "fail"):
+ self.action_if_job_exists = action_if_job_exists
+ else:
+ raise AirflowException(
+ "Argument action_if_job_exists accepts only 'increment' and 'fail'. "
+ f"Provided value: '{action_if_job_exists}'."
+ )
+
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: list[list[str]] = [
+ ["Transform", "TransformResources", "InstanceCount"],
+ ["Transform", "MaxConcurrentTransforms"],
+ ["Transform", "MaxPayloadInMB"],
]
- if 'Transform' not in self.config:
+ if "Transform" not in self.config:
for field in self.integer_fields:
field.pop(0)
def expand_role(self) -> None:
- if 'Model' not in self.config:
+ """Expands an IAM role name into an ARN."""
+ if "Model" not in self.config:
return
- config = self.config['Model']
- if 'ExecutionRoleArn' in config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])
+ config = self.config["Model"]
+ if "ExecutionRoleArn" in config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
+ config["ExecutionRoleArn"] = hook.expand_role(config["ExecutionRoleArn"])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
- model_config = self.config.get('Model')
- transform_config = self.config.get('Transform', self.config)
+ model_config = self.config.get("Model")
+ transform_config = self.config.get("Transform", self.config)
+ if self.check_if_job_exists:
+ self._check_if_transform_job_exists()
if model_config:
- self.log.info('Creating SageMaker Model %s for transform job', model_config['ModelName'])
+ self.log.info("Creating SageMaker Model %s for transform job", model_config["ModelName"])
self.hook.create_model(model_config)
- self.log.info('Creating SageMaker transform Job %s.', transform_config['TransformJobName'])
+ self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"])
response = self.hook.create_transform_job(
transform_config,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker transform Job creation failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
else:
return {
- 'Model': self.hook.describe_model(transform_config['ModelName']),
- 'Transform': self.hook.describe_transform_job(transform_config['TransformJobName']),
+ "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
+ "Transform": serialize(
+ self.hook.describe_transform_job(transform_config["TransformJobName"])
+ ),
}
+ def _check_if_transform_job_exists(self) -> None:
+ transform_config = self.config.get("Transform", self.config)
+ transform_job_name = transform_config["TransformJobName"]
+ transform_jobs = self.hook.list_transform_jobs(name_contains=transform_job_name)
+ if transform_job_name in [tj["TransformJobName"] for tj in transform_jobs]:
+ if self.action_if_job_exists == "increment":
+ self.log.info("Found existing transform job with name '%s'.", transform_job_name)
+ new_transform_job_name = f"{transform_job_name}-{(len(transform_jobs) + 1)}"
+ transform_config["TransformJobName"] = new_transform_job_name
+ self.log.info("Incremented transform job name to '%s'.", new_transform_job_name)
+ elif self.action_if_job_exists == "fail":
+ raise AirflowException(
+ f"A SageMaker transform job with name {transform_job_name} already exists."
+ )
+
class SageMakerTuningOperator(SageMakerBaseOperator):
"""
@@ -485,14 +539,6 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the tuning job created in Amazon SageMaker.
"""
- integer_fields = [
- ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxNumberOfTrainingJobs'],
- ['HyperParameterTuningJobConfig', 'ResourceLimits', 'MaxParallelTrainingJobs'],
- ['TrainingJobDefinition', 'ResourceConfig', 'InstanceCount'],
- ['TrainingJobDefinition', 'ResourceConfig', 'VolumeSizeInGB'],
- ['TrainingJobDefinition', 'StoppingCondition', 'MaxRuntimeInSeconds'],
- ]
-
def __init__(
self,
*,
@@ -500,27 +546,36 @@ def __init__(
aws_conn_id: str = DEFAULT_CONN_ID,
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
**kwargs,
):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.aws_conn_id = aws_conn_id
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
def expand_role(self) -> None:
- if 'TrainingJobDefinition' in self.config:
- config = self.config['TrainingJobDefinition']
- if 'RoleArn' in config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- config['RoleArn'] = hook.expand_role(config['RoleArn'])
+ """Expands an IAM role name into an ARN."""
+ if "TrainingJobDefinition" in self.config:
+ config = self.config["TrainingJobDefinition"]
+ if "RoleArn" in config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
+ config["RoleArn"] = hook.expand_role(config["RoleArn"])
+
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: list[list[str]] = [
+ ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxNumberOfTrainingJobs"],
+ ["HyperParameterTuningJobConfig", "ResourceLimits", "MaxParallelTrainingJobs"],
+ ["TrainingJobDefinition", "ResourceConfig", "InstanceCount"],
+ ["TrainingJobDefinition", "ResourceConfig", "VolumeSizeInGB"],
+ ["TrainingJobDefinition", "StoppingCondition", "MaxRuntimeInSeconds"],
+ ]
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
self.log.info(
- 'Creating SageMaker Hyper-Parameter Tuning Job %s', self.config['HyperParameterTuningJobName']
+ "Creating SageMaker Hyper-Parameter Tuning Job %s", self.config["HyperParameterTuningJobName"]
)
response = self.hook.create_tuning_job(
self.config,
@@ -528,10 +583,12 @@ def execute(self, context: 'Context') -> dict:
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker Tuning Job creation failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker Tuning Job creation failed: {response}")
else:
- return {'Tuning': self.hook.describe_tuning_job(self.config['HyperParameterTuningJobName'])}
+ return {
+ "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
+ }
class SageMakerModelOperator(SageMakerBaseOperator):
@@ -552,24 +609,23 @@ class SageMakerModelOperator(SageMakerBaseOperator):
:return Dict: Returns The ARN of the model created in Amazon SageMaker.
"""
- def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.aws_conn_id = aws_conn_id
+ def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
def expand_role(self) -> None:
- if 'ExecutionRoleArn' in self.config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])
+ """Expands an IAM role name into an ARN."""
+ if "ExecutionRoleArn" in self.config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
+ self.config["ExecutionRoleArn"] = hook.expand_role(self.config["ExecutionRoleArn"])
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
- self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
+ self.log.info("Creating SageMaker Model %s.", self.config["ModelName"])
response = self.hook.create_model(self.config)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker model creation failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker model creation failed: {response}")
else:
- return {'Model': self.hook.describe_model(self.config['ModelName'])}
+ return {"Model": serialize(self.hook.describe_model(self.config["ModelName"]))}
class SageMakerTrainingOperator(SageMakerBaseOperator):
@@ -597,16 +653,10 @@ class SageMakerTrainingOperator(SageMakerBaseOperator):
already exists for the name in the config.
:param action_if_job_exists: Behaviour if the job name already exists. Possible options are "increment"
(default) and "fail".
- This is only relevant if check_if
+ This is only relevant if check_if_job_exists is True.
:return Dict: Returns The ARN of the training job created in Amazon SageMaker.
"""
- integer_fields = [
- ['ResourceConfig', 'InstanceCount'],
- ['ResourceConfig', 'VolumeSizeInGB'],
- ['StoppingCondition', 'MaxRuntimeInSeconds'],
- ]
-
def __init__(
self,
*,
@@ -615,19 +665,18 @@ def __init__(
wait_for_completion: bool = True,
print_log: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
check_if_job_exists: bool = True,
- action_if_job_exists: str = 'increment',
+ action_if_job_exists: str = "increment",
**kwargs,
):
- super().__init__(config=config, **kwargs)
- self.aws_conn_id = aws_conn_id
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.print_log = print_log
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.check_if_job_exists = check_if_job_exists
- if action_if_job_exists in ('increment', 'fail'):
+ if action_if_job_exists in ("increment", "fail"):
self.action_if_job_exists = action_if_job_exists
else:
raise AirflowException(
@@ -636,15 +685,24 @@ def __init__(
)
def expand_role(self) -> None:
- if 'RoleArn' in self.config:
- hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
- self.config['RoleArn'] = hook.expand_role(self.config['RoleArn'])
+ """Expands an IAM role name into an ARN."""
+ if "RoleArn" in self.config:
+ hook = AwsBaseHook(self.aws_conn_id, client_type="iam")
+ self.config["RoleArn"] = hook.expand_role(self.config["RoleArn"])
+
+ def _create_integer_fields(self) -> None:
+ """Set fields which should be cast to integers."""
+ self.integer_fields: list[list[str]] = [
+ ["ResourceConfig", "InstanceCount"],
+ ["ResourceConfig", "VolumeSizeInGB"],
+ ["StoppingCondition", "MaxRuntimeInSeconds"],
+ ]
- def execute(self, context: 'Context') -> dict:
+ def execute(self, context: Context) -> dict:
self.preprocess_config()
if self.check_if_job_exists:
self._check_if_job_exists()
- self.log.info('Creating SageMaker training job %s.', self.config['TrainingJobName'])
+ self.log.info("Creating SageMaker training job %s.", self.config["TrainingJobName"])
response = self.hook.create_training_job(
self.config,
wait_for_completion=self.wait_for_completion,
@@ -652,23 +710,23 @@ def execute(self, context: 'Context') -> dict:
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- raise AirflowException(f'Sagemaker Training Job creation failed: {response}')
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ raise AirflowException(f"Sagemaker Training Job creation failed: {response}")
else:
- return {'Training': self.hook.describe_training_job(self.config['TrainingJobName'])}
+ return {"Training": serialize(self.hook.describe_training_job(self.config["TrainingJobName"]))}
def _check_if_job_exists(self) -> None:
- training_job_name = self.config['TrainingJobName']
+ training_job_name = self.config["TrainingJobName"]
training_jobs = self.hook.list_training_jobs(name_contains=training_job_name)
- if training_job_name in [tj['TrainingJobName'] for tj in training_jobs]:
- if self.action_if_job_exists == 'increment':
+ if training_job_name in [tj["TrainingJobName"] for tj in training_jobs]:
+ if self.action_if_job_exists == "increment":
self.log.info("Found existing training job with name '%s'.", training_job_name)
- new_training_job_name = f'{training_job_name}-{(len(training_jobs) + 1)}'
- self.config['TrainingJobName'] = new_training_job_name
+ new_training_job_name = f"{training_job_name}-{(len(training_jobs) + 1)}"
+ self.config["TrainingJobName"] = new_training_job_name
self.log.info("Incremented training job name to '%s'.", new_training_job_name)
- elif self.action_if_job_exists == 'fail':
+ elif self.action_if_job_exists == "fail":
raise AirflowException(
- f'A SageMaker training job with name {training_job_name} already exists.'
+ f"A SageMaker training job with name {training_job_name} already exists."
)
@@ -685,12 +743,10 @@ class SageMakerDeleteModelOperator(SageMakerBaseOperator):
:param aws_conn_id: The AWS connection ID to use.
"""
- def __init__(self, *, config, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
- super().__init__(config=config, **kwargs)
- self.config = config
- self.aws_conn_id = aws_conn_id
+ def __init__(self, *, config: dict, aws_conn_id: str = DEFAULT_CONN_ID, **kwargs):
+ super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
- def execute(self, context: 'Context') -> Any:
+ def execute(self, context: Context) -> Any:
sagemaker_hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
- sagemaker_hook.delete_model(model_name=self.config['ModelName'])
- self.log.info("Model %s deleted successfully.", self.config['ModelName'])
+ sagemaker_hook.delete_model(model_name=self.config["ModelName"])
+ self.log.info("Model %s deleted successfully.", self.config["ModelName"])
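For reference, the "increment" behaviour that this diff adds to `SageMakerTransformOperator` (mirroring the existing training operator logic) can be sketched outside Airflow roughly as follows. This is a minimal illustration rather than the provider's code: `next_job_name` is a hypothetical helper, and `existing_jobs` stands in for whatever `list_transform_jobs(name_contains=...)` returns.

```python
from __future__ import annotations


def next_job_name(job_name: str, existing_jobs: list[dict], action: str = "increment") -> str:
    """Hypothetical helper mirroring the job-name check shown in the diff."""
    if action not in ("increment", "fail"):
        raise ValueError(f"action must be 'increment' or 'fail', got {action!r}")
    existing_names = [job["TransformJobName"] for job in existing_jobs]
    if job_name not in existing_names:
        # No exact match: the requested name can be used unchanged.
        return job_name
    if action == "fail":
        raise RuntimeError(f"A transform job with name {job_name} already exists.")
    # Same rule as the diff: the suffix is the number of listed jobs plus one.
    return f"{job_name}-{len(existing_jobs) + 1}"


# Assumed sample data: two jobs whose names contain "nightly".
jobs = [{"TransformJobName": "nightly"}, {"TransformJobName": "nightly-2"}]
print(next_job_name("nightly", jobs))  # nightly-3
print(next_job_name("weekly", jobs))   # weekly
```

Note that the suffix is derived from the count of jobs whose names merely contain the requested name, so an unrelated job that happens to share the substring would also raise the counter.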
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_base.py b/airflow/providers/amazon/aws/operators/sagemaker_base.py
deleted file mode 100644
index 22e44d71c8ca0..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_base.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
deleted file mode 100644
index 5351431f00f68..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py b/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
deleted file mode 100644
index 737e5b6c7d2c8..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_endpoint_config.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_model.py b/airflow/providers/amazon/aws/operators/sagemaker_model.py
deleted file mode 100644
index fffe8d96e4b69..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_model.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerModelOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_processing.py b/airflow/providers/amazon/aws/operators/sagemaker_processing.py
deleted file mode 100644
index b3a4be8fa2aa3..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_processing.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerProcessingOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_training.py b/airflow/providers/amazon/aws/operators/sagemaker_training.py
deleted file mode 100644
index 40f13b45e932f..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_training.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTrainingOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_transform.py b/airflow/providers/amazon/aws/operators/sagemaker_transform.py
deleted file mode 100644
index 1e833edebcb16..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_transform.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py b/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
deleted file mode 100644
index 18a8263a4f102..0000000000000
--- a/airflow/providers/amazon/aws/operators/sagemaker_tuning.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
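The modules deleted above were thin compatibility shims: each re-exported one operator from `airflow.providers.amazon.aws.operators.sagemaker` and emitted a `DeprecationWarning` at import time. A minimal, Airflow-free sketch of that warning pattern is shown below; `load_legacy_module` is a hypothetical stand-in for importing one of the removed files.

```python
import warnings


def load_legacy_module() -> None:
    """Hypothetical stand-in for importing one of the removed shim modules."""
    warnings.warn(
        "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not to the shim itself
    )


# Record warnings so the behaviour can be inspected instead of printed to stderr.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    load_legacy_module()

print(caught[0].category.__name__)  # DeprecationWarning
```

With the shims removed, code that still imports from the old module paths would fail at import time instead of receiving this warning, so such imports need to point at the `sagemaker` module directly.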
diff --git a/airflow/providers/amazon/aws/operators/sns.py b/airflow/providers/amazon/aws/operators/sns.py
index e916798d03386..99525c4cb6bc2 100644
--- a/airflow/providers/amazon/aws/operators/sns.py
+++ b/airflow/providers/amazon/aws/operators/sns.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Publish message to SNS queue"""
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sns import SnsHook
@@ -42,7 +43,7 @@ class SnsPublishOperator(BaseOperator):
determined automatically)
"""
- template_fields: Sequence[str] = ('message', 'subject', 'message_attributes')
+ template_fields: Sequence[str] = ("target_arn", "message", "subject", "message_attributes", "aws_conn_id")
template_ext: Sequence[str] = ()
template_fields_renderers = {"message_attributes": "json"}
@@ -51,9 +52,9 @@ def __init__(
*,
target_arn: str,
message: str,
- aws_conn_id: str = 'aws_default',
- subject: Optional[str] = None,
- message_attributes: Optional[dict] = None,
+ subject: str | None = None,
+ message_attributes: dict | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -63,11 +64,11 @@ def __init__(
self.message_attributes = message_attributes
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
sns = SnsHook(aws_conn_id=self.aws_conn_id)
self.log.info(
- 'Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s',
+ "Sending SNS notification to %s using %s:\nsubject=%s\nattributes=%s\nmessage=%s",
self.target_arn,
self.aws_conn_id,
self.subject,
diff --git a/airflow/providers/amazon/aws/operators/sqs.py b/airflow/providers/amazon/aws/operators/sqs.py
index 6eff54134a75d..0b0cfc4f16297 100644
--- a/airflow/providers/amazon/aws/operators/sqs.py
+++ b/airflow/providers/amazon/aws/operators/sqs.py
@@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Publish message to SQS queue"""
-import warnings
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
@@ -39,21 +39,30 @@ class SqsPublishOperator(BaseOperator):
:param message_attributes: additional attributes for the message (default: None)
For details of the attributes parameter see :py:meth:`botocore.client.SQS.send_message`
:param delay_seconds: message delay (templated) (default: 1 second)
+ :param message_group_id: This parameter applies only to FIFO (first-in-first-out) queues. (default: None)
+ For details of this parameter see :py:meth:`botocore.client.SQS.send_message`
:param aws_conn_id: AWS connection id (default: aws_default)
"""
- template_fields: Sequence[str] = ('sqs_queue', 'message_content', 'delay_seconds', 'message_attributes')
- template_fields_renderers = {'message_attributes': 'json'}
- ui_color = '#6ad3fa'
+ template_fields: Sequence[str] = (
+ "sqs_queue",
+ "message_content",
+ "delay_seconds",
+ "message_attributes",
+ "message_group_id",
+ )
+ template_fields_renderers = {"message_attributes": "json"}
+ ui_color = "#6ad3fa"
def __init__(
self,
*,
sqs_queue: str,
message_content: str,
- message_attributes: Optional[dict] = None,
+ message_attributes: dict | None = None,
delay_seconds: int = 0,
- aws_conn_id: str = 'aws_default',
+ message_group_id: str | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
@@ -62,15 +71,15 @@ def __init__(
self.message_content = message_content
self.delay_seconds = delay_seconds
self.message_attributes = message_attributes or {}
+ self.message_group_id = message_group_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context) -> dict:
"""
Publish the message to the Amazon SQS queue
:param context: the context object
:return: dict with information about the message sent
For details of the returned dict see :py:meth:`botocore.client.SQS.send_message`
- :rtype: dict
"""
hook = SqsHook(aws_conn_id=self.aws_conn_id)
@@ -79,24 +88,9 @@ def execute(self, context: 'Context'):
message_body=self.message_content,
delay_seconds=self.delay_seconds,
message_attributes=self.message_attributes,
+ message_group_id=self.message_group_id,
)
- self.log.info('send_message result: %s', result)
+ self.log.info("send_message result: %s", result)
return result
-
-
-class SQSPublishOperator(SqsPublishOperator):
- """
- This operator is deprecated.
- Please use :class:`airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This operator is deprecated. "
- "Please use `airflow.providers.amazon.aws.operators.sqs.SqsPublishOperator`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/operators/step_function.py b/airflow/providers/amazon/aws/operators/step_function.py
index 7c32b33890117..c131dabc742e1 100644
--- a/airflow/providers/amazon/aws/operators/step_function.py
+++ b/airflow/providers/amazon/aws/operators/step_function.py
@@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from __future__ import annotations
import json
-from typing import TYPE_CHECKING, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -44,18 +44,18 @@ class StepFunctionStartExecutionOperator(BaseOperator):
:param do_xcom_push: if True, execution_arn is pushed to XCom with key execution_arn.
"""
- template_fields: Sequence[str] = ('state_machine_arn', 'name', 'input')
+ template_fields: Sequence[str] = ("state_machine_arn", "name", "input")
template_ext: Sequence[str] = ()
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
def __init__(
self,
*,
state_machine_arn: str,
- name: Optional[str] = None,
- state_machine_input: Union[dict, str, None] = None,
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
+ name: str | None = None,
+ state_machine_input: dict | str | None = None,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -65,15 +65,15 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
execution_arn = hook.start_execution(self.state_machine_arn, self.name, self.input)
if execution_arn is None:
- raise AirflowException(f'Failed to start State Machine execution for: {self.state_machine_arn}')
+ raise AirflowException(f"Failed to start State Machine execution for: {self.state_machine_arn}")
- self.log.info('Started State Machine execution for %s: %s', self.state_machine_arn, execution_arn)
+ self.log.info("Started State Machine execution for %s: %s", self.state_machine_arn, execution_arn)
return execution_arn
@@ -92,16 +92,16 @@ class StepFunctionGetExecutionOutputOperator(BaseOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
"""
- template_fields: Sequence[str] = ('execution_arn',)
+ template_fields: Sequence[str] = ("execution_arn",)
template_ext: Sequence[str] = ()
- ui_color = '#f9c915'
+ ui_color = "#f9c915"
def __init__(
self,
*,
execution_arn: str,
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -109,12 +109,12 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hook = StepFunctionHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
execution_status = hook.describe_execution(self.execution_arn)
- execution_output = json.loads(execution_status['output']) if 'output' in execution_status else None
+ execution_output = json.loads(execution_status["output"]) if "output" in execution_status else None
- self.log.info('Got State Machine Execution output for %s', self.execution_arn)
+ self.log.info("Got State Machine Execution output for %s", self.execution_arn)
return execution_output
diff --git a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py b/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
deleted file mode 100644
index 2b047241ae861..0000000000000
--- a/airflow/providers/amazon/aws/operators/step_function_get_execution_output.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.step_function import ( # noqa
- StepFunctionGetExecutionOutputOperator,
-)
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/operators/step_function_start_execution.py b/airflow/providers/amazon/aws/operators/step_function_start_execution.py
deleted file mode 100644
index 10a847ffc90bd..0000000000000
--- a/airflow/providers/amazon/aws/operators/step_function_start_execution.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.step_function`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.operators.step_function import StepFunctionStartExecutionOperator # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.operators.step_function`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/secrets/secrets_manager.py b/airflow/providers/amazon/aws/secrets/secrets_manager.py
index 8b72f955ac6d5..684a8d4afbee9 100644
--- a/airflow/providers/amazon/aws/secrets/secrets_manager.py
+++ b/airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -16,31 +16,22 @@
# specific language governing permissions and limitations
# under the License.
"""Objects relating to sourcing secrets from AWS Secrets Manager"""
+from __future__ import annotations
import ast
import json
-import re
-import sys
import warnings
-from typing import Optional
-from urllib.parse import urlencode
-
-import boto3
-
-from airflow.version import version as airflow_version
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from typing import TYPE_CHECKING, Any
+from urllib.parse import unquote, urlencode
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.utils import get_airflow_version, trim_none_values
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin
-
-def _parse_version(val):
- val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
- return tuple(int(x) for x in val.split('.'))
+if TYPE_CHECKING:
+ # Avoid circular import problems when instantiating the backend during configuration.
+ from airflow.models.connection import Connection
class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
@@ -63,8 +54,17 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
if you provide ``{"config_prefix": "airflow/config"}`` and request config
key ``sql_alchemy_conn``.
- You can also pass additional keyword arguments like ``aws_secret_access_key``, ``aws_access_key_id``
- or ``region_name`` to this class and they would be passed on to Boto3 client.
+ You can also pass additional keyword arguments listed in AWS Connection Extra config
+ to this class, and they would be used for establishing a connection and passed on to Boto3 client.
+
+ .. code-block:: ini
+
+ [secrets]
+ backend = airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
+ backend_kwargs = {"connections_prefix": "airflow/connections", "region_name": "eu-west-1"}
+
+ .. seealso::
+ :ref:`howto/connection:aws:configuring-the-connection`
There are two ways of storing secrets in Secret Manager for using them with this operator:
storing them as a conn URI in one field, or taking advantage of native approach of Secrets Manager
@@ -74,7 +74,7 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
.. code-block:: python
possible_words_for_conn_fields = {
- "user": ["user", "username", "login", "user_name"],
+ "login": ["user", "username", "login", "user_name"],
"password": ["password", "pass", "key"],
"host": ["host", "remote_host", "server"],
"port": ["port"],
@@ -95,11 +95,13 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
:param config_prefix: Specifies the prefix of the secret to read to get Configurations.
If set to None (null value in the configuration), requests for configurations will not be sent to
AWS Secrets Manager. If you don't want a config_prefix, set it as an empty string
- :param profile_name: The name of a profile to use. If not given, then the default profile is used.
:param sep: separator used to concatenate secret_prefix and secret_id. Default: "/"
:param full_url_mode: if True, the secrets must be stored as one conn URI in just one field per secret.
If False (set it as false in backend_kwargs), you can store the secret using different
fields (password, user...).
+ :param are_secret_values_urlencoded: If True, and full_url_mode is False, then the values are assumed to
+ be URL-encoded and will be decoded before being passed into a Connection object. This option is
+ ignored when full_url_mode is True.
:param extra_conn_words: for using just when you set full_url_mode as false and store
the secrets in different fields of secrets manager. You can add more words for each connection
part beyond the default ones. The extra words to be searched should be passed as a dict of lists,
@@ -109,13 +111,13 @@ class SecretsManagerBackend(BaseSecretsBackend, LoggingMixin):
def __init__(
self,
- connections_prefix: str = 'airflow/connections',
- variables_prefix: str = 'airflow/variables',
- config_prefix: str = 'airflow/config',
- profile_name: Optional[str] = None,
+ connections_prefix: str = "airflow/connections",
+ variables_prefix: str = "airflow/variables",
+ config_prefix: str = "airflow/config",
sep: str = "/",
full_url_mode: bool = True,
- extra_conn_words: Optional[dict] = None,
+ are_secret_values_urlencoded: bool | None = None,
+ extra_conn_words: dict[str, list[str]] | None = None,
**kwargs,
):
super().__init__()
@@ -131,23 +133,62 @@ def __init__(
self.config_prefix = config_prefix.rstrip(sep)
else:
self.config_prefix = config_prefix
- self.profile_name = profile_name
self.sep = sep
self.full_url_mode = full_url_mode
+
+ if are_secret_values_urlencoded is None:
+ self.are_secret_values_urlencoded = True
+ else:
+ warnings.warn(
+ "The `are_secret_values_urlencoded` kwarg only exists to assist in migrating away from"
+ " URL-encoding secret values when `full_url_mode` is False. It will be considered deprecated"
+ " when values are not required to be URL-encoded by default.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ if full_url_mode and not are_secret_values_urlencoded:
+ warnings.warn(
+ "The `are_secret_values_urlencoded` kwarg for the SecretsManagerBackend is only used"
+ " when `full_url_mode` is False. When `full_url_mode` is True, the secret needs to be"
+ " URL-encoded.",
+ UserWarning,
+ stacklevel=2,
+ )
+ self.are_secret_values_urlencoded = are_secret_values_urlencoded
self.extra_conn_words = extra_conn_words or {}
+
+ self.profile_name = kwargs.get("profile_name", None)
+ # Remove client specific arguments from kwargs
+ self.api_version = kwargs.pop("api_version", None)
+ self.use_ssl = kwargs.pop("use_ssl", None)
+
self.kwargs = kwargs
@cached_property
def client(self):
"""Create a Secrets Manager client"""
- session = boto3.session.Session(profile_name=self.profile_name)
-
- return session.client(service_name="secretsmanager", **self.kwargs)
+ from airflow.providers.amazon.aws.hooks.base_aws import SessionFactory
+ from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+
+ conn_id = f"{self.__class__.__name__}__connection"
+ conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs)
+ client_kwargs = trim_none_values(
+ {
+ "region_name": conn_config.region_name,
+ "verify": conn_config.verify,
+ "endpoint_url": conn_config.endpoint_url,
+ "api_version": self.api_version,
+ "use_ssl": self.use_ssl,
+ }
+ )
+
+ session = SessionFactory(conn=conn_config).create_session()
+ return session.client(service_name="secretsmanager", **client_kwargs)
@staticmethod
- def _format_uri_with_extra(secret, conn_string):
+ def _format_uri_with_extra(secret, conn_string: str) -> str:
try:
- extra_dict = secret['extra']
+ extra_dict = secret["extra"]
except KeyError:
return conn_string
@@ -156,31 +197,161 @@ def _format_uri_with_extra(secret, conn_string):
return conn_string
- def get_uri_from_secret(self, secret):
+ def get_connection(self, conn_id: str) -> Connection | None:
+ if not self.full_url_mode:
+ # Avoid circular import problems when instantiating the backend during configuration.
+ from airflow.models.connection import Connection
+
+ secret_string = self._get_secret(self.connections_prefix, conn_id)
+ secret_dict = self._deserialize_json_string(secret_string)
+
+ if not secret_dict:
+ return None
+
+ if "extra" in secret_dict and isinstance(secret_dict["extra"], str):
+ secret_dict["extra"] = self._deserialize_json_string(secret_dict["extra"])
+
+ data = self._standardize_secret_keys(secret_dict)
+
+ if self.are_secret_values_urlencoded:
+ data = self._remove_escaping_in_secret_dict(secret=data, conn_id=conn_id)
+
+ port: int | None = None
+
+ if data["port"] is not None:
+ port = int(data["port"])
+
+ return Connection(
+ conn_id=conn_id,
+ login=data["user"],
+ password=data["password"],
+ host=data["host"],
+ port=port,
+ schema=data["schema"],
+ conn_type=data["conn_type"],
+ extra=data["extra"],
+ )
+
+ return super().get_connection(conn_id=conn_id)
+
+ def _standardize_secret_keys(self, secret: dict[str, Any]) -> dict[str, Any]:
+ """Standardize the names of the keys in the dict."""
possible_words_for_conn_fields = {
- 'user': ['user', 'username', 'login', 'user_name'],
- 'password': ['password', 'pass', 'key'],
- 'host': ['host', 'remote_host', 'server'],
- 'port': ['port'],
- 'schema': ['database', 'schema'],
- 'conn_type': ['conn_type', 'conn_id', 'connection_type', 'engine'],
+ "user": ["user", "username", "login", "user_name"],
+ "password": ["password", "pass", "key"],
+ "host": ["host", "remote_host", "server"],
+ "port": ["port"],
+ "schema": ["database", "schema"],
+ "conn_type": ["conn_type", "conn_id", "connection_type", "engine"],
+ "extra": ["extra"],
}
for conn_field, extra_words in self.extra_conn_words.items():
possible_words_for_conn_fields[conn_field].extend(extra_words)
- conn_d = {}
+ conn_d: dict[str, Any] = {}
for conn_field, possible_words in possible_words_for_conn_fields.items():
try:
conn_d[conn_field] = [v for k, v in secret.items() if k in possible_words][0]
except IndexError:
- conn_d[conn_field] = ''
+ conn_d[conn_field] = None
- conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d)
+ return conn_d
+ def get_uri_from_secret(self, secret: dict[str, str]) -> str:
+ conn_d: dict[str, str] = {k: v if v else "" for k, v in self._standardize_secret_keys(secret).items()}
+ conn_string = "{conn_type}://{user}:{password}@{host}:{port}/{schema}".format(**conn_d)
return self._format_uri_with_extra(secret, conn_string)
- def get_conn_value(self, conn_id: str):
+ def _deserialize_json_string(self, value: str | None) -> dict[Any, Any] | None:
+ if not value:
+ return None
+ try:
+ # Try json.loads first and fall back to ast.literal_eval for backwards compatibility.
+ # A previous version of this code had a comment saying that using json.loads caused errors.
+ # This likely means people were storing dict reprs instead of valid JSON.
+ res: dict[str, Any] = json.loads(value)
+ except json.JSONDecodeError:
+ try:
+ res = ast.literal_eval(value) if value else None
+ warnings.warn(
+ f"In future versions, `{type(self).__name__}` will only support valid JSONs, not dict"
+ " reprs. Please make sure your secret is a valid JSON."
+ )
+ except ValueError: # 'malformed node or string: ' error, for empty conns
+ return None
+
+ return res
+
+ def _remove_escaping_in_secret_dict(self, secret: dict[str, Any], conn_id: str) -> dict[str, Any]:
+ # When ``unquote(v) == v``, then removing unquote won't affect the user, regardless of
+ # whether or not ``v`` is URL-encoded. For example, "foo bar" is not URL-encoded. But
+ # because decoding it doesn't affect the value, then it will migrate safely when
+ # ``unquote`` gets removed.
+ #
+ # When parameters are URL-encoded, but decoding is idempotent, we need to warn the user
+ # to un-escape their secrets. For example, if "foo%20bar" is a URL-encoded string, then
+ # decoding is idempotent because ``unquote(unquote("foo%20bar")) == unquote("foo%20bar")``.
+ #
+ # In the rare situation that value is URL-encoded but the decoding is _not_ idempotent,
+ # this causes a major issue. For example, if "foo%2520bar" is URL-encoded, then decoding is
+ # _not_ idempotent because ``unquote(unquote("foo%2520bar")) != unquote("foo%2520bar")``
+ #
+ # This causes a problem for migration because if the user decodes their value, we cannot
+ # infer that is the case by looking at the decoded value (from our vantage point, it will
+ # look to be URL-encoded.)
+ #
+ # So when this uncommon situation occurs, the user _must_ adjust the configuration and set
+ # ``are_secret_values_urlencoded`` to False to migrate safely. In all other cases, we do not
+ # need the user to adjust this object to migrate; they can transition their secrets with
+ # the default configuration.
+
+ warn_user = False
+ idempotent = True
+
+ for k, v in secret.copy().items():
+
+ if k == "extra" and isinstance(v, dict):
+ # The old behavior was that extras were _not_ urlencoded inside the secret.
+ # If they were urlencoded (e.g. "foo%20bar"), then they would be re-urlencoded
+ # (e.g. "foo%20bar" becomes "foo%2520bar") and then unquoted once when parsed.
+ # So we should just allow the extra dict to remain as-is.
+ continue
+
+ elif v is not None:
+ v_unquoted = unquote(v)
+ if v != v_unquoted:
+ secret[k] = unquote(v)
+ warn_user = True
+
+ # Check to see if decoding is idempotent.
+ if v_unquoted != unquote(v_unquoted):
+ idempotent = False
+
+ if warn_user:
+ msg = (
+ "When full_url_mode=False, URL-encoding secret values is deprecated. In future versions, "
+ f"this value will not be un-escaped. For the conn_id {conn_id!r}, please remove the "
+ "URL-encoding.\n\n"
+ "This warning was raised because the SecretsManagerBackend detected that this "
+ "connection was URL-encoded."
+ )
+ if idempotent:
+ msg += f" Once the values for conn_id {conn_id!r} are decoded, this warning will go away."
+ if not idempotent:
+ msg += (
+ " In addition to decoding the values for your connection, you must also set"
+ " are_secret_values_urlencoded=False for your config variable"
+ " secrets.backend_kwargs because this connection's URL encoding is not idempotent."
+ " For more information, see:"
+ " https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/secrets-backends"
+ "/aws-secrets-manager.html#url-encoding-of-secrets-when-full-url-mode-is-false"
+ )
+ warnings.warn(msg, DeprecationWarning, stacklevel=2)
+
+ return secret
+
+ def get_conn_value(self, conn_id: str) -> str | None:
"""
Get serialized representation of Connection
@@ -191,13 +362,30 @@ def get_conn_value(self, conn_id: str):
if self.full_url_mode:
return self._get_secret(self.connections_prefix, conn_id)
- try:
- secret_string = self._get_secret(self.connections_prefix, conn_id)
- # json.loads gives error
- secret = ast.literal_eval(secret_string) if secret_string else None
- except ValueError: # 'malformed node or string: ' error, for empty conns
- connection = None
- secret = None
+ else:
+ warnings.warn(
+ f"In future versions, `{type(self).__name__}.get_conn_value` will return a JSON string when"
+ " full_url_mode is False, not a URI.",
+ DeprecationWarning,
+ )
+
+ # It is very rare for user code to get to this point, since:
+ #
+ # - When full_url_mode is True, the previous statement returns.
+ # - When full_url_mode is False, get_connection() does not call
+ # `get_conn_value`. Additionally, full_url_mode defaults to True.
+ #
+ # So the code would have to be calling `get_conn_value` directly, and
+ # the user would be using a non-default setting.
+ #
+ # As of Airflow 2.3.0, get_conn_value() is allowed to return a JSON
+ # string in the base implementation. This is a way to deprecate this
+ # behavior gracefully.
+
+ secret_string = self._get_secret(self.connections_prefix, conn_id)
+
+ secret = self._deserialize_json_string(secret_string)
+ connection = None
# These lines will check if we have with some denomination stored an username, password and host
if secret:
@@ -205,7 +393,7 @@ def get_conn_value(self, conn_id: str):
return connection
- def get_conn_uri(self, conn_id: str) -> Optional[str]:
+ def get_conn_uri(self, conn_id: str) -> str | None:
"""
Return URI representation of Connection conn_id.
@@ -214,7 +402,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:
:param conn_id: the connection id
:return: deserialized Connection
"""
- if _parse_version(airflow_version) >= (2, 3):
+ if get_airflow_version() >= (2, 3):
warnings.warn(
f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
"in a future release. Please use method `get_conn_value` instead.",
@@ -223,7 +411,7 @@ def get_conn_uri(self, conn_id: str) -> Optional[str]:
)
return self.get_conn_value(conn_id)
- def get_variable(self, key: str) -> Optional[str]:
+ def get_variable(self, key: str) -> str | None:
"""
Get Airflow Variable from Environment Variable
:param key: Variable Key
@@ -234,7 +422,7 @@ def get_variable(self, key: str) -> Optional[str]:
return self._get_secret(self.variables_prefix, key)
- def get_config(self, key: str) -> Optional[str]:
+ def get_config(self, key: str) -> str | None:
"""
Get Airflow Configuration
:param key: Configuration Option Key
@@ -245,12 +433,13 @@ def get_config(self, key: str) -> Optional[str]:
return self._get_secret(self.config_prefix, key)
- def _get_secret(self, path_prefix, secret_id: str) -> Optional[str]:
+ def _get_secret(self, path_prefix, secret_id: str) -> str | None:
"""
Get secret value from Secrets Manager
:param path_prefix: Prefix for the Path to get Secret
:param secret_id: Secret Key
"""
+ error_msg = "An error occurred when calling the get_secret_value operation"
if path_prefix:
secrets_path = self.build_path(path_prefix, secret_id, self.sep)
else:
@@ -260,18 +449,39 @@ def _get_secret(self, path_prefix, secret_id: str) -> Optional[str]:
response = self.client.get_secret_value(
SecretId=secrets_path,
)
- return response.get('SecretString')
+ return response.get("SecretString")
except self.client.exceptions.ResourceNotFoundException:
self.log.debug(
- "An error occurred (ResourceNotFoundException) when calling the "
- "get_secret_value operation: "
- "Secret %s not found.",
+ "ResourceNotFoundException: %s. Secret %s not found.",
+ error_msg,
secret_id,
)
return None
- except self.client.exceptions.AccessDeniedException:
+ except self.client.exceptions.InvalidParameterException:
+ self.log.debug(
+ "InvalidParameterException: %s",
+ error_msg,
+ exc_info=True,
+ )
+ return None
+ except self.client.exceptions.InvalidRequestException:
+ self.log.debug(
+ "InvalidRequestException: %s",
+ error_msg,
+ exc_info=True,
+ )
+ return None
+ except self.client.exceptions.DecryptionFailure:
+ self.log.debug(
+ "DecryptionFailure: %s",
+ error_msg,
+ exc_info=True,
+ )
+ return None
+ except self.client.exceptions.InternalServiceError:
self.log.debug(
- "An error occurred (AccessDeniedException) when calling the get_secret_value operation",
+ "InternalServiceError: %s",
+ error_msg,
exc_info=True,
)
return None
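
The migration comments in `_remove_escaping_in_secret_dict` above distinguish three cases for a stored value: decoding leaves it unchanged, decoding changes it once and then stabilises (idempotent), or a second decode changes it again (not idempotent). The following is a minimal standalone sketch of that classification using only `urllib.parse.unquote`. The function name `classify_url_encoding` and its return labels are illustrative assumptions and are not part of the provider.

```python
from urllib.parse import unquote


def classify_url_encoding(value: str) -> str:
    """Hypothetical helper: label a value by how URL-decoding affects it."""
    decoded_once = unquote(value)
    if decoded_once == value:
        # Decoding is a no-op, so dropping unquote() later changes nothing.
        return "unchanged"
    if unquote(decoded_once) == decoded_once:
        # A second decode is a no-op: decoding is idempotent.
        return "idempotent"
    # A second decode changes the value again (e.g. "%2520" -> "%20" -> " ").
    return "not idempotent"


if __name__ == "__main__":
    for sample in ("foo bar", "foo%20bar", "foo%2520bar"):
        print(sample, "->", classify_url_encoding(sample))
```

Only the last case requires the user to set `are_secret_values_urlencoded` to False while migrating, because the backend cannot tell a decoded `foo%20bar` from a value that is still encoded.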
diff --git a/airflow/providers/amazon/aws/secrets/systems_manager.py b/airflow/providers/amazon/aws/secrets/systems_manager.py
index e45a5500ab003..af45cb5e9813b 100644
--- a/airflow/providers/amazon/aws/secrets/systems_manager.py
+++ b/airflow/providers/amazon/aws/secrets/systems_manager.py
@@ -16,29 +16,16 @@
# specific language governing permissions and limitations
# under the License.
"""Objects relating to sourcing connections from AWS SSM Parameter Store"""
-import re
-import sys
-import warnings
-from typing import Optional
-
-import boto3
-
-from airflow.version import version as airflow_version
+from __future__ import annotations
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+import warnings
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.utils import trim_none_values
from airflow.secrets import BaseSecretsBackend
from airflow.utils.log.logging_mixin import LoggingMixin
-def _parse_version(val):
- val = re.sub(r'(\d+\.\d+\.\d+).*', lambda x: x.group(1), val)
- return tuple(int(x) for x in val.split('.'))
-
-
class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
"""
Retrieves Connection or Variables from AWS SSM Parameter Store
@@ -62,15 +49,26 @@ class SystemsManagerParameterStoreBackend(BaseSecretsBackend, LoggingMixin):
If set to None (null), requests for variables will not be sent to AWS SSM Parameter Store.
:param config_prefix: Specifies the prefix of the secret to read to get Variables.
If set to None (null), requests for configurations will not be sent to AWS SSM Parameter Store.
- :param profile_name: The name of a profile to use. If not given, then the default profile is used.
+
+ You can also pass additional keyword arguments listed in AWS Connection Extra config
+ to this class, and they would be used for establishing a connection and passed on to Boto3 client.
+
+ .. code-block:: ini
+
+ [secrets]
+ backend = airflow.providers.amazon.aws.secrets.systems_manager.SystemsManagerParameterStoreBackend
+ backend_kwargs = {"connections_prefix": "airflow/connections", "region_name": "eu-west-1"}
+
+ .. seealso::
+ :ref:`howto/connection:aws:configuring-the-connection`
+
"""
def __init__(
self,
- connections_prefix: str = '/airflow/connections',
- variables_prefix: str = '/airflow/variables',
- config_prefix: str = '/airflow/config',
- profile_name: Optional[str] = None,
+ connections_prefix: str = "/airflow/connections",
+ variables_prefix: str = "/airflow/variables",
+ config_prefix: str = "/airflow/config",
**kwargs,
):
super().__init__()
@@ -79,23 +77,43 @@ def __init__(
else:
self.connections_prefix = connections_prefix
if variables_prefix is not None:
- self.variables_prefix = variables_prefix.rstrip('/')
+ self.variables_prefix = variables_prefix.rstrip("/")
else:
self.variables_prefix = variables_prefix
if config_prefix is not None:
- self.config_prefix = config_prefix.rstrip('/')
+ self.config_prefix = config_prefix.rstrip("/")
else:
self.config_prefix = config_prefix
- self.profile_name = profile_name
+
+ self.profile_name = kwargs.get("profile_name", None)
+ # Remove client specific arguments from kwargs
+ self.api_version = kwargs.pop("api_version", None)
+ self.use_ssl = kwargs.pop("use_ssl", None)
+
self.kwargs = kwargs
@cached_property
def client(self):
"""Create a SSM client"""
- session = boto3.Session(profile_name=self.profile_name)
- return session.client("ssm", **self.kwargs)
-
- def get_conn_value(self, conn_id: str) -> Optional[str]:
+ from airflow.providers.amazon.aws.hooks.base_aws import SessionFactory
+ from airflow.providers.amazon.aws.utils.connection_wrapper import AwsConnectionWrapper
+
+ conn_id = f"{self.__class__.__name__}__connection"
+ conn_config = AwsConnectionWrapper.from_connection_metadata(conn_id=conn_id, extra=self.kwargs)
+ client_kwargs = trim_none_values(
+ {
+ "region_name": conn_config.region_name,
+ "verify": conn_config.verify,
+ "endpoint_url": conn_config.endpoint_url,
+ "api_version": self.api_version,
+ "use_ssl": self.use_ssl,
+ }
+ )
+
+ session = SessionFactory(conn=conn_config).create_session()
+ return session.client(service_name="ssm", **client_kwargs)
+
+ def get_conn_value(self, conn_id: str) -> str | None:
"""
Get param value
@@ -106,25 +124,27 @@ def get_conn_value(self, conn_id: str) -> Optional[str]:
return self._get_secret(self.connections_prefix, conn_id)
- def get_conn_uri(self, conn_id: str) -> Optional[str]:
+ def get_conn_uri(self, conn_id: str) -> str | None:
"""
Return URI representation of Connection conn_id.
As of Airflow version 2.3.0 this method is deprecated.
:param conn_id: the connection id
- :return: deserialized Connection
"""
- if _parse_version(airflow_version) >= (2, 3):
- warnings.warn(
- f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
- "in a future release. Please use method `get_conn_value` instead.",
- DeprecationWarning,
- stacklevel=2,
- )
- return self.get_conn_value(conn_id)
-
- def get_variable(self, key: str) -> Optional[str]:
+ warnings.warn(
+ f"Method `{self.__class__.__name__}.get_conn_uri` is deprecated and will be removed "
+ "in a future release. Please use method `get_conn_value` instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ value = self.get_conn_value(conn_id)
+ if value is None:
+ return None
+
+ return self.deserialize_connection(conn_id, value).get_uri()
+
+ def get_variable(self, key: str) -> str | None:
"""
Get Airflow Variable from Environment Variable
@@ -136,7 +156,7 @@ def get_variable(self, key: str) -> Optional[str]:
return self._get_secret(self.variables_prefix, key)
- def get_config(self, key: str) -> Optional[str]:
+ def get_config(self, key: str) -> str | None:
"""
Get Airflow Configuration
@@ -148,7 +168,7 @@ def get_config(self, key: str) -> Optional[str]:
return self._get_secret(self.config_prefix, key)
- def _get_secret(self, path_prefix: str, secret_id: str) -> Optional[str]:
+ def _get_secret(self, path_prefix: str, secret_id: str) -> str | None:
"""
Get secret value from Parameter Store.
diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py
index 927f512143362..1954d15a8e004 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -15,17 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
if TYPE_CHECKING:
from airflow.utils.context import Context
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.sensors.base import BaseSensorOperator
@@ -37,7 +34,7 @@ class AthenaSensor(BaseSensorOperator):
If the query fails, the task will fail.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:AthenaSensor`
@@ -50,25 +47,25 @@ class AthenaSensor(BaseSensorOperator):
"""
INTERMEDIATE_STATES = (
- 'QUEUED',
- 'RUNNING',
+ "QUEUED",
+ "RUNNING",
)
FAILURE_STATES = (
- 'FAILED',
- 'CANCELLED',
+ "FAILED",
+ "CANCELLED",
)
- SUCCESS_STATES = ('SUCCEEDED',)
+ SUCCESS_STATES = ("SUCCEEDED",)
- template_fields: Sequence[str] = ('query_execution_id',)
+ template_fields: Sequence[str] = ("query_execution_id",)
template_ext: Sequence[str] = ()
- ui_color = '#66c3ff'
+ ui_color = "#66c3ff"
def __init__(
self,
*,
query_execution_id: str,
- max_retries: Optional[int] = None,
- aws_conn_id: str = 'aws_default',
+ max_retries: int | None = None,
+ aws_conn_id: str = "aws_default",
sleep_time: int = 10,
**kwargs: Any,
) -> None:
@@ -78,11 +75,11 @@ def __init__(
self.sleep_time = sleep_time
self.max_retries = max_retries
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
state = self.hook.poll_query_status(self.query_execution_id, self.max_retries)
if state in self.FAILURE_STATES:
- raise AirflowException('Athena sensor failed')
+ raise AirflowException("Athena sensor failed")
if state in self.INTERMEDIATE_STATES:
return False
diff --git a/airflow/providers/amazon/aws/sensors/batch.py b/airflow/providers/amazon/aws/sensors/batch.py
index faab424e314fa..facd01a00e4cf 100644
--- a/airflow/providers/amazon/aws/sensors/batch.py
+++ b/airflow/providers/amazon/aws/sensors/batch.py
@@ -14,8 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, Sequence
+import sys
+from typing import TYPE_CHECKING, Sequence
+
+if sys.version_info >= (3, 8):
+ from functools import cached_property
+else:
+ from cached_property import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
@@ -36,29 +43,30 @@ class BatchSensor(BaseSensorOperator):
:param job_id: Batch job_id to check the state for
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ :param region_name: aws region name associated with the client
"""
- template_fields: Sequence[str] = ('job_id',)
+ template_fields: Sequence[str] = ("job_id",)
template_ext: Sequence[str] = ()
- ui_color = '#66c3ff'
+ ui_color = "#66c3ff"
def __init__(
self,
*,
job_id: str,
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.job_id = job_id
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- self.hook: Optional[BatchClientHook] = None
+ self.hook: BatchClientHook | None = None
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
job_description = self.get_hook().get_job_description(self.job_id)
- state = job_description['status']
+ state = job_description["status"]
if state == BatchClientHook.SUCCESS_STATE:
return True
@@ -67,9 +75,9 @@ def poke(self, context: 'Context') -> bool:
return False
if state == BatchClientHook.FAILURE_STATE:
- raise AirflowException(f'Batch sensor failed. AWS Batch job status: {state}')
+ raise AirflowException(f"Batch sensor failed. AWS Batch job status: {state}")
- raise AirflowException(f'Batch sensor failed. Unknown AWS Batch job status: {state}')
+ raise AirflowException(f"Batch sensor failed. Unknown AWS Batch job status: {state}")
def get_hook(self) -> BatchClientHook:
"""Create and return a BatchClientHook"""
@@ -81,3 +89,129 @@ def get_hook(self) -> BatchClientHook:
region_name=self.region_name,
)
return self.hook
+
+
+class BatchComputeEnvironmentSensor(BaseSensorOperator):
+ """
+ Asks for the state of the Batch compute environment until it reaches a failure state or success state.
+ If the environment fails, the task will fail.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:BatchComputeEnvironmentSensor`
+
+ :param compute_environment: Batch compute environment name
+
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+
+ :param region_name: aws region name associated with the client
+ """
+
+ template_fields: Sequence[str] = ("compute_environment",)
+ template_ext: Sequence[str] = ()
+ ui_color = "#66c3ff"
+
+ def __init__(
+ self,
+ compute_environment: str,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.compute_environment = compute_environment
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ """Create and return a BatchClientHook"""
+ return BatchClientHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ )
+
+ def poke(self, context: Context) -> bool:
+ response = self.hook.client.describe_compute_environments(
+ computeEnvironments=[self.compute_environment]
+ )
+
+ if len(response["computeEnvironments"]) == 0:
+ raise AirflowException(f"AWS Batch compute environment {self.compute_environment} not found")
+
+ status = response["computeEnvironments"][0]["status"]
+
+ if status in BatchClientHook.COMPUTE_ENVIRONMENT_TERMINAL_STATUS:
+ return True
+
+ if status in BatchClientHook.COMPUTE_ENVIRONMENT_INTERMEDIATE_STATUS:
+ return False
+
+ raise AirflowException(
+ f"AWS Batch compute environment failed. AWS Batch compute environment status: {status}"
+ )
+
+
+class BatchJobQueueSensor(BaseSensorOperator):
+ """
+ Asks for the state of the Batch job queue until it reaches a failure state or success state.
+ If the queue fails, the task will fail.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:BatchJobQueueSensor`
+
+ :param job_queue: Batch job queue name
+
+ :param treat_non_existing_as_deleted: If True, a non-existing Batch job queue is considered as a deleted
+ queue and as such a valid case.
+
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+
+ :param region_name: aws region name associated with the client
+ """
+
+ template_fields: Sequence[str] = ("job_queue",)
+ template_ext: Sequence[str] = ()
+ ui_color = "#66c3ff"
+
+ def __init__(
+ self,
+ job_queue: str,
+ treat_non_existing_as_deleted: bool = False,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.job_queue = job_queue
+ self.treat_non_existing_as_deleted = treat_non_existing_as_deleted
+ self.aws_conn_id = aws_conn_id
+ self.region_name = region_name
+
+ @cached_property
+ def hook(self) -> BatchClientHook:
+ """Create and return a BatchClientHook"""
+ return BatchClientHook(
+ aws_conn_id=self.aws_conn_id,
+ region_name=self.region_name,
+ )
+
+ def poke(self, context: Context) -> bool:
+ response = self.hook.client.describe_job_queues(jobQueues=[self.job_queue])
+
+ if len(response["jobQueues"]) == 0:
+ if self.treat_non_existing_as_deleted:
+ return True
+ else:
+ raise AirflowException(f"AWS Batch job queue {self.job_queue} not found")
+
+ status = response["jobQueues"][0]["status"]
+
+ if status in BatchClientHook.JOB_QUEUE_TERMINAL_STATUS:
+ return True
+
+ if status in BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS:
+ return False
+
+ raise AirflowException(f"AWS Batch job queue failed. AWS Batch job queue status: {status}")
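The `poke` logic added for `BatchJobQueueSensor` is a three-way classification of the reported status. The following is a minimal standalone sketch of that decision, not the provider's code: the status tuples are assumed values standing in for `BatchClientHook.JOB_QUEUE_TERMINAL_STATUS` and `BatchClientHook.JOB_QUEUE_INTERMEDIATE_STATUS` (this diff does not show their definitions), and `classify_job_queue` and `SensorError` are hypothetical names used only for illustration.

```python
from __future__ import annotations

# Assumed stand-ins for the BatchClientHook constants; the real values are
# defined in the hook and are not shown in this diff.
TERMINAL_STATUS = ("VALID", "DELETED")
INTERMEDIATE_STATUS = ("CREATING", "UPDATING", "DELETING")


class SensorError(Exception):
    """Hypothetical stand-in for AirflowException."""


def classify_job_queue(response: dict, treat_non_existing_as_deleted: bool = False) -> bool:
    """Return True when the sensor should succeed, False when it should keep poking."""
    queues = response["jobQueues"]
    if not queues:
        # A missing queue is only acceptable when the caller opted in.
        if treat_non_existing_as_deleted:
            return True
        raise SensorError("job queue not found")

    status = queues[0]["status"]
    if status in TERMINAL_STATUS:
        return True
    if status in INTERMEDIATE_STATUS:
        return False
    # Anything else is treated as a failure, mirroring the final raise in poke().
    raise SensorError(f"job queue failed, status: {status}")
```

Any status outside both tuples falls through to the final `raise`, so an unexpected value fails the task instead of poking forever.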
diff --git a/airflow/providers/amazon/aws/sensors/cloud_formation.py b/airflow/providers/amazon/aws/sensors/cloud_formation.py
index fb01bdc7f6827..d2bd45592654f 100644
--- a/airflow/providers/amazon/aws/sensors/cloud_formation.py
+++ b/airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -16,17 +16,14 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains sensors for AWS CloudFormation."""
-import sys
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
if TYPE_CHECKING:
from airflow.utils.context import Context
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
+from airflow.compat.functools import cached_property
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
from airflow.sensors.base import BaseSensorOperator
@@ -36,36 +33,35 @@ class CloudFormationCreateStackSensor(BaseSensorOperator):
Waits for a stack to be created successfully on AWS CloudFormation.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:CloudFormationCreateStackSensor`
-
:param stack_name: The name of the stack to wait for (templated)
:param aws_conn_id: ID of the Airflow connection where credentials and extra configuration are
stored
:param poke_interval: Time in seconds that the job should wait between each try
"""
- template_fields: Sequence[str] = ('stack_name',)
- ui_color = '#C5CAE9'
+ template_fields: Sequence[str] = ("stack_name",)
+ ui_color = "#C5CAE9"
- def __init__(self, *, stack_name, aws_conn_id='aws_default', region_name=None, **kwargs):
+ def __init__(self, *, stack_name, aws_conn_id="aws_default", region_name=None, **kwargs):
super().__init__(**kwargs)
self.stack_name = stack_name
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
stack_status = self.hook.get_stack_status(self.stack_name)
- if stack_status == 'CREATE_COMPLETE':
+ if stack_status == "CREATE_COMPLETE":
return True
- if stack_status in ('CREATE_IN_PROGRESS', None):
+ if stack_status in ("CREATE_IN_PROGRESS", None):
return False
- raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}')
+ raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")
@cached_property
def hook(self) -> CloudFormationHook:
- """Create and return an CloudFormationHook"""
+ """Create and return a CloudFormationHook"""
return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
@@ -74,7 +70,7 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
Waits for a stack to be deleted successfully on AWS CloudFormation.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:CloudFormationDeleteStackSensor`
:param stack_name: The name of the stack to wait for (templated)
@@ -83,15 +79,15 @@ class CloudFormationDeleteStackSensor(BaseSensorOperator):
:param poke_interval: Time in seconds that the job should wait between each try
"""
- template_fields: Sequence[str] = ('stack_name',)
- ui_color = '#C5CAE9'
+ template_fields: Sequence[str] = ("stack_name",)
+ ui_color = "#C5CAE9"
def __init__(
self,
*,
stack_name: str,
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -99,15 +95,15 @@ def __init__(
self.region_name = region_name
self.stack_name = stack_name
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
stack_status = self.hook.get_stack_status(self.stack_name)
- if stack_status in ('DELETE_COMPLETE', None):
+ if stack_status in ("DELETE_COMPLETE", None):
return True
- if stack_status == 'DELETE_IN_PROGRESS':
+ if stack_status == "DELETE_IN_PROGRESS":
return False
- raise ValueError(f'Stack {self.stack_name} in bad state: {stack_status}')
+ raise ValueError(f"Stack {self.stack_name} in bad state: {stack_status}")
@cached_property
def hook(self) -> CloudFormationHook:
- """Create and return an CloudFormationHook"""
+ """Create and return a CloudFormationHook"""
return CloudFormationHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
diff --git a/airflow/providers/amazon/aws/sensors/dms.py b/airflow/providers/amazon/aws/sensors/dms.py
index 26e6b7148fdcd..9a05d77a21253 100644
--- a/airflow/providers/amazon/aws/sensors/dms.py
+++ b/airflow/providers/amazon/aws/sensors/dms.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Iterable, Optional, Sequence
+from typing import TYPE_CHECKING, Iterable, Sequence
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.dms import DmsHook
@@ -40,15 +41,15 @@ class DmsTaskBaseSensor(BaseSensorOperator):
the task reaches any of these states
"""
- template_fields: Sequence[str] = ('replication_task_arn',)
+ template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
def __init__(
self,
replication_task_arn: str,
- aws_conn_id='aws_default',
- target_statuses: Optional[Iterable[str]] = None,
- termination_statuses: Optional[Iterable[str]] = None,
+ aws_conn_id="aws_default",
+ target_statuses: Iterable[str] | None = None,
+ termination_statuses: Iterable[str] | None = None,
*args,
**kwargs,
):
@@ -57,7 +58,7 @@ def __init__(
self.replication_task_arn = replication_task_arn
self.target_statuses: Iterable[str] = target_statuses or []
self.termination_statuses: Iterable[str] = termination_statuses or []
- self.hook: Optional[DmsHook] = None
+ self.hook: DmsHook | None = None
def get_hook(self) -> DmsHook:
"""Get DmsHook"""
@@ -67,21 +68,21 @@ def get_hook(self) -> DmsHook:
self.hook = DmsHook(self.aws_conn_id)
return self.hook
- def poke(self, context: 'Context'):
- status: Optional[str] = self.get_hook().get_task_status(self.replication_task_arn)
+ def poke(self, context: Context):
+ status: str | None = self.get_hook().get_task_status(self.replication_task_arn)
if not status:
raise AirflowException(
- f'Failed to read task status, task with ARN {self.replication_task_arn} not found'
+ f"Failed to read task status, task with ARN {self.replication_task_arn} not found"
)
- self.log.info('DMS Replication task (%s) has status: %s', self.replication_task_arn, status)
+ self.log.info("DMS Replication task (%s) has status: %s", self.replication_task_arn, status)
if status in self.target_statuses:
return True
if status in self.termination_statuses:
- raise AirflowException(f'Unexpected status: {status}')
+ raise AirflowException(f"Unexpected status: {status}")
return False
@@ -91,25 +92,25 @@ class DmsTaskCompletedSensor(DmsTaskBaseSensor):
Pokes DMS task until it is completed.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:DmsTaskCompletedSensor`
:param replication_task_arn: AWS DMS replication task ARN
"""
- template_fields: Sequence[str] = ('replication_task_arn',)
+ template_fields: Sequence[str] = ("replication_task_arn",)
template_ext: Sequence[str] = ()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self.target_statuses = ['stopped']
+ self.target_statuses = ["stopped"]
self.termination_statuses = [
- 'creating',
- 'deleting',
- 'failed',
- 'failed-move',
- 'modifying',
- 'moving',
- 'ready',
- 'testing',
+ "creating",
+ "deleting",
+ "failed",
+ "failed-move",
+ "modifying",
+ "moving",
+ "ready",
+ "testing",
]
diff --git a/airflow/providers/amazon/aws/sensors/dms_task.py b/airflow/providers/amazon/aws/sensors/dms_task.py
deleted file mode 100644
index 26a06c5812556..0000000000000
--- a/airflow/providers/amazon/aws/sensors/dms_task.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.dms`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.dms import DmsTaskBaseSensor, DmsTaskCompletedSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.dms`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/ec2.py b/airflow/providers/amazon/aws/sensors/ec2.py
index 7d4a640593fec..bfdb8fcd411c0 100644
--- a/airflow/providers/amazon/aws/sensors/ec2.py
+++ b/airflow/providers/amazon/aws/sensors/ec2.py
@@ -15,9 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
from airflow.sensors.base import BaseSensorOperator
@@ -51,7 +51,7 @@ def __init__(
target_state: str,
instance_id: str,
aws_conn_id: str = "aws_default",
- region_name: Optional[str] = None,
+ region_name: str | None = None,
**kwargs,
):
if target_state not in self.valid_states:
@@ -62,7 +62,7 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
ec2_hook = EC2Hook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
instance_state = ec2_hook.get_instance_state(instance_id=self.instance_id)
self.log.info("instance state: %s", instance_state)
diff --git a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py b/airflow/providers/amazon/aws/sensors/ec2_instance_state.py
deleted file mode 100644
index d166b69d20ffb..0000000000000
--- a/airflow/providers/amazon/aws/sensors/ec2_instance_state.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.ec2`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.ec2 import EC2InstanceStateSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.ec2`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/ecs.py b/airflow/providers/amazon/aws/sensors/ecs.py
new file mode 100644
index 0000000000000..c1151b8f9a55c
--- /dev/null
+++ b/airflow/providers/amazon/aws/sensors/ecs.py
@@ -0,0 +1,188 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
+
+import boto3
+
+from airflow.exceptions import AirflowException
+from airflow.compat.functools import cached_property
+from airflow.providers.amazon.aws.hooks.ecs import (
+ EcsClusterStates,
+ EcsHook,
+ EcsTaskDefinitionStates,
+ EcsTaskStates,
+)
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+DEFAULT_CONN_ID: str = "aws_default"
+
+
+def _check_failed(current_state, target_state, failure_states):
+ if (current_state != target_state) and (current_state in failure_states):
+ raise AirflowException(
+ f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
+ )
+
+
+class EcsBaseSensor(BaseSensorOperator):
+ """Contains general sensor behavior for Elastic Container Service."""
+
+ def __init__(self, *, aws_conn_id: str | None = DEFAULT_CONN_ID, region: str | None = None, **kwargs):
+ self.aws_conn_id = aws_conn_id
+ self.region = region
+ super().__init__(**kwargs)
+
+ @cached_property
+ def hook(self) -> EcsHook:
+ """Create and return an EcsHook."""
+ return EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region)
+
+ @cached_property
+ def client(self) -> boto3.client:
+ """Create and return an EcsHook client."""
+ return self.hook.conn
+
+
+class EcsClusterStateSensor(EcsBaseSensor):
+ """
+ Polls the cluster state until it reaches a terminal state. Raises an
+ AirflowException with the failure reason if a failed state is reached.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EcsClusterStateSensor`
+
+ :param cluster_name: The name of your cluster.
+ :param target_state: Success state to watch for. (Default: "ACTIVE")
+ :param failure_states: Fail if any of these states are reached before the
+ Success State. (Default: "FAILED" or "INACTIVE")
+ """
+
+ template_fields: Sequence[str] = ("cluster_name", "target_state", "failure_states")
+
+ def __init__(
+ self,
+ *,
+ cluster_name: str,
+ target_state: EcsClusterStates | None = EcsClusterStates.ACTIVE,
+ failure_states: set[EcsClusterStates] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.cluster_name = cluster_name
+ self.target_state = target_state
+ self.failure_states = failure_states or {EcsClusterStates.FAILED, EcsClusterStates.INACTIVE}
+
+ def poke(self, context: Context):
+ cluster_state = EcsClusterStates(self.hook.get_cluster_state(cluster_name=self.cluster_name))
+
+ self.log.info("Cluster state: %s, waiting for: %s", cluster_state, self.target_state)
+ _check_failed(cluster_state, self.target_state, self.failure_states)
+
+ return cluster_state == self.target_state
+
+
+class EcsTaskDefinitionStateSensor(EcsBaseSensor):
+ """
+ Polls the task definition state until it reaches a terminal state. Raises an
+ AirflowException with the failure reason if a failed state is reached.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EcsTaskDefinitionStateSensor`
+
+ :param task_definition: The family for the latest ACTIVE revision, family and
+ revision (family:revision) for a specific revision in the family, or full
+ Amazon Resource Name (ARN) of the task definition.
+ :param target_state: Success state to watch for. (Default: "ACTIVE")
+ """
+
+ template_fields: Sequence[str] = ("task_definition", "target_state", "failure_states")
+
+ def __init__(
+ self,
+ *,
+ task_definition: str,
+ target_state: EcsTaskDefinitionStates | None = EcsTaskDefinitionStates.ACTIVE,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.task_definition = task_definition
+ self.target_state = target_state
+ # There are only two possible states, so set failure_state to whatever is not the target_state
+ self.failure_states = {
+ (
+ EcsTaskDefinitionStates.INACTIVE
+ if target_state == EcsTaskDefinitionStates.ACTIVE
+ else EcsTaskDefinitionStates.ACTIVE
+ )
+ }
+
+ def poke(self, context: Context):
+ task_definition_state = EcsTaskDefinitionStates(
+ self.hook.get_task_definition_state(task_definition=self.task_definition)
+ )
+
+ self.log.info("Task Definition state: %s, waiting for: %s", task_definition_state, self.target_state)
+ _check_failed(task_definition_state, self.target_state, self.failure_states)
+ return task_definition_state == self.target_state
+
+
+class EcsTaskStateSensor(EcsBaseSensor):
+ """
+ Polls the task state until it reaches a terminal state. Raises an
+ AirflowException with the failure reason if a failed state is reached.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EcsTaskStateSensor`
+
+ :param cluster: The short name or full Amazon Resource Name (ARN) of the cluster that hosts the task.
+ :param task: The task ID or full ARN of the task to poll.
+ :param target_state: Success state to watch for. (Default: "RUNNING")
+ :param failure_states: Fail if any of these states are reached before
+ the Success State. (Default: "STOPPED")
+ """
+
+ template_fields: Sequence[str] = ("cluster", "task", "target_state", "failure_states")
+
+ def __init__(
+ self,
+ *,
+ cluster: str,
+ task: str,
+ target_state: EcsTaskStates | None = EcsTaskStates.RUNNING,
+ failure_states: set[EcsTaskStates] | None = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.cluster = cluster
+ self.task = task
+ self.target_state = target_state
+ self.failure_states = failure_states or {EcsTaskStates.STOPPED}
+
+ def poke(self, context: Context):
+ task_state = EcsTaskStates(self.hook.get_task_state(cluster=self.cluster, task=self.task))
+
+ self.log.info("Task state: %s, waiting for: %s", task_state, self.target_state)
+ _check_failed(task_state, self.target_state, self.failure_states)
+ return task_state == self.target_state
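The three ECS sensors share the `_check_failed` helper, which raises only when the current state differs from the target and is a member of the failure collection. The sketch below reproduces that check in isolation, under stated assumptions: `TaskDefinitionState` and `SensorFailure` are hypothetical stand-ins for `EcsTaskDefinitionStates` and `AirflowException`, so the snippet runs without Airflow installed.

```python
from enum import Enum


class SensorFailure(Exception):
    """Hypothetical stand-in for AirflowException."""


class TaskDefinitionState(Enum):
    """Hypothetical stand-in for EcsTaskDefinitionStates."""

    ACTIVE = "ACTIVE"
    INACTIVE = "INACTIVE"


def check_failed(current_state, target_state, failure_states):
    # Raise only when a failure state is reached before the target state.
    if (current_state != target_state) and (current_state in failure_states):
        raise SensorFailure(
            f"Terminal state reached. Current state: {current_state}, Expected state: {target_state}"
        )


failure_states = {TaskDefinitionState.INACTIVE}

# Passing the set itself: membership is tested against the enum members.
try:
    check_failed(TaskDefinitionState.INACTIVE, TaskDefinitionState.ACTIVE, failure_states)
    raised = False
except SensorFailure:
    raised = True

# Passing a list that wraps the set: the only element is the set, so the
# enum member is never found and the check silently passes.
check_failed(TaskDefinitionState.INACTIVE, TaskDefinitionState.ACTIVE, [failure_states])
```

The third argument has to be the collection of states itself. If it is wrapped in another container, the membership test compares the state against the inner collection and the failure is never detected.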
diff --git a/airflow/providers/amazon/aws/sensors/eks.py b/airflow/providers/amazon/aws/sensors/eks.py
index 92ed55da4d31e..f2ee37215173d 100644
--- a/airflow/providers/amazon/aws/sensors/eks.py
+++ b/airflow/providers/amazon/aws/sensors/eks.py
@@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""Tracking the state of Amazon EKS Clusters, Amazon EKS managed node groups, and AWS Fargate profiles."""
-import warnings
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.eks import (
@@ -85,7 +85,7 @@ def __init__(
cluster_name: str,
target_state: ClusterStates = ClusterStates.ACTIVE,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
):
self.cluster_name = cluster_name
@@ -98,7 +98,7 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -121,7 +121,7 @@ class EksFargateProfileStateSensor(BaseSensorOperator):
Check the state of an AWS Fargate profile until it reaches the target state or another terminal state.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:EksFargateProfileStateSensor`
:param cluster_name: The name of the Cluster which the AWS Fargate profile is attached to. (templated)
@@ -153,7 +153,7 @@ def __init__(
fargate_profile_name: str,
target_state: FargateProfileStates = FargateProfileStates.ACTIVE,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
):
self.cluster_name = cluster_name
@@ -167,7 +167,7 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -224,7 +224,7 @@ def __init__(
nodegroup_name: str,
target_state: NodegroupStates = NodegroupStates.ACTIVE,
aws_conn_id: str = DEFAULT_CONN_ID,
- region: Optional[str] = None,
+ region: str | None = None,
**kwargs,
):
self.cluster_name = cluster_name
@@ -238,7 +238,7 @@ def __init__(
self.region = region
super().__init__(**kwargs)
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
eks_hook = EksHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region,
@@ -256,51 +256,3 @@ def poke(self, context: 'Context'):
)
)
return nodegroup_state == self.target_state
-
-
-class EKSClusterStateSensor(EksClusterStateSensor):
- """
- This sensor is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.eks.EksClusterStateSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This sensor is deprecated. "
- "Please use `airflow.providers.amazon.aws.sensors.eks.EksClusterStateSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSFargateProfileStateSensor(EksFargateProfileStateSensor):
- """
- This sensor is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.eks.EksFargateProfileStateSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This sensor is deprecated. "
- "Please use `airflow.providers.amazon.aws.sensors.eks.EksFargateProfileStateSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
-
-
-class EKSNodegroupStateSensor(EksNodegroupStateSensor):
- """
- This sensor is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.eks.EksNodegroupStateSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This sensor is deprecated. "
- "Please use `airflow.providers.amazon.aws.sensors.eks.EksNodegroupStateSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/sensors/emr.py b/airflow/providers/amazon/aws/sensors/emr.py
index a4c2b3a71142b..a3684fa249a1d 100644
--- a/airflow/providers/amazon/aws/sensors/emr.py
+++ b/airflow/providers/amazon/aws/sensors/emr.py
@@ -15,21 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence
+from __future__ import annotations
-if TYPE_CHECKING:
- from airflow.utils.context import Context
+from typing import TYPE_CHECKING, Any, Iterable, Sequence
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook
+from airflow.sensors.base import BaseSensorOperator, poke_mode_only
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook
-from airflow.sensors.base import BaseSensorOperator
+from airflow.compat.functools import cached_property
class EmrBaseSensor(BaseSensorOperator):
@@ -43,17 +40,17 @@ class EmrBaseSensor(BaseSensorOperator):
Subclasses should set ``target_states`` and ``failed_states`` fields.
- :param aws_conn_id: aws connection to uses
+ :param aws_conn_id: aws connection to use
"""
- ui_color = '#66c3ff'
+ ui_color = "#66c3ff"
- def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs):
+ def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.target_states: Iterable[str] = [] # will be set in subclasses
self.failed_states: Iterable[str] = [] # will be set in subclasses
- self.hook: Optional[EmrHook] = None
+ self.hook: EmrHook | None = None
def get_hook(self) -> EmrHook:
"""Get EmrHook"""
@@ -63,58 +60,173 @@ def get_hook(self) -> EmrHook:
self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
return self.hook
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
response = self.get_emr_response()
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- self.log.info('Bad HTTP response: %s', response)
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ self.log.info("Bad HTTP response: %s", response)
return False
state = self.state_from_response(response)
- self.log.info('Job flow currently %s', state)
+ self.log.info("Job flow currently %s", state)
if state in self.target_states:
return True
if state in self.failed_states:
- final_message = 'EMR job failed'
+ final_message = "EMR job failed"
failure_message = self.failure_message_from_response(response)
if failure_message:
- final_message += ' ' + failure_message
+ final_message += " " + failure_message
raise AirflowException(final_message)
return False
- def get_emr_response(self) -> Dict[str, Any]:
+ def get_emr_response(self) -> dict[str, Any]:
"""
Make an API call with boto3 and get response.
:return: response
- :rtype: dict[str, Any]
"""
- raise NotImplementedError('Please implement get_emr_response() in subclass')
+ raise NotImplementedError("Please implement get_emr_response() in subclass")
@staticmethod
- def state_from_response(response: Dict[str, Any]) -> str:
+ def state_from_response(response: dict[str, Any]) -> str:
"""
Get state from response dictionary.
:param response: response from AWS API
:return: state
- :rtype: str
"""
- raise NotImplementedError('Please implement state_from_response() in subclass')
+ raise NotImplementedError("Please implement state_from_response() in subclass")
+
+ @staticmethod
+ def failure_message_from_response(response: dict[str, Any]) -> str | None:
+ """
+ Get failure message from response dictionary.
+
+ :param response: response from AWS API
+ :return: failure message
+ """
+ raise NotImplementedError("Please implement failure_message_from_response() in subclass")
+
+
+class EmrServerlessJobSensor(BaseSensorOperator):
+ """
+ Asks for the state of the job run until it reaches a failure state or success state.
+ If the job run fails, the task will fail.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EmrServerlessJobSensor`
+
+ :param application_id: application_id to check the state of
+ :param job_run_id: job_run_id to check the state of
+ :param target_states: a set of states to wait for, defaults to 'SUCCESS'
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ """
+
+ template_fields: Sequence[str] = (
+ "application_id",
+ "job_run_id",
+ )
+
+ def __init__(
+ self,
+ *,
+ application_id: str,
+ job_run_id: str,
+ target_states: set | frozenset = frozenset(EmrServerlessHook.JOB_SUCCESS_STATES),
+ aws_conn_id: str = "aws_default",
+ **kwargs: Any,
+ ) -> None:
+ self.aws_conn_id = aws_conn_id
+ self.target_states = target_states
+ self.application_id = application_id
+ self.job_run_id = job_run_id
+ super().__init__(**kwargs)
+
+ def poke(self, context: Context) -> bool:
+ response = self.hook.conn.get_job_run(applicationId=self.application_id, jobRunId=self.job_run_id)
+
+ state = response["jobRun"]["state"]
+
+ if state in EmrServerlessHook.JOB_FAILURE_STATES:
+ failure_message = f"EMR Serverless job failed: {self.failure_message_from_response(response)}"
+ raise AirflowException(failure_message)
+
+ return state in self.target_states
+
+ @cached_property
+ def hook(self) -> EmrServerlessHook:
+ """Create and return an EmrServerlessHook"""
+ return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
+
+ @staticmethod
+ def failure_message_from_response(response: dict[str, Any]) -> str | None:
+ """
+ Get failure message from response dictionary.
+
+ :param response: response from AWS API
+ :return: failure message
+ """
+ return response["jobRun"]["stateDetails"]
+
+
+class EmrServerlessApplicationSensor(BaseSensorOperator):
+ """
+ Asks for the state of the application until it reaches a failure state or success state.
+ If the application fails, the task will fail.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:EmrServerlessApplicationSensor`
+
+ :param application_id: application_id to check the state of
+ :param target_states: a set of states to wait for, defaults to {'CREATED', 'STARTED'}
+ :param aws_conn_id: aws connection to use, defaults to 'aws_default'
+ """
+
+ template_fields: Sequence[str] = ("application_id",)
+
+ def __init__(
+ self,
+ *,
+ application_id: str,
+ target_states: set | frozenset = frozenset(EmrServerlessHook.APPLICATION_SUCCESS_STATES),
+ aws_conn_id: str = "aws_default",
+ **kwargs: Any,
+ ) -> None:
+ self.aws_conn_id = aws_conn_id
+ self.target_states = target_states
+ self.application_id = application_id
+ super().__init__(**kwargs)
+
+ def poke(self, context: Context) -> bool:
+ response = self.hook.conn.get_application(applicationId=self.application_id)
+
+ state = response["application"]["state"]
+
+ if state in EmrServerlessHook.APPLICATION_FAILURE_STATES:
+ failure_message = f"EMR Serverless application failed: {self.failure_message_from_response(response)}"
+ raise AirflowException(failure_message)
+
+ return state in self.target_states
+
+ @cached_property
+ def hook(self) -> EmrServerlessHook:
+ """Create and return an EmrServerlessHook"""
+ return EmrServerlessHook(aws_conn_id=self.aws_conn_id)
@staticmethod
- def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
+ def failure_message_from_response(response: dict[str, Any]) -> str | None:
"""
Get failure message from response dictionary.
:param response: response from AWS API
:return: failure message
- :rtype: Optional[str]
"""
- raise NotImplementedError('Please implement failure_message_from_response() in subclass')
+ return response["application"]["stateDetails"]
class EmrContainerSensor(BaseSensorOperator):
@@ -146,17 +258,17 @@ class EmrContainerSensor(BaseSensorOperator):
)
SUCCESS_STATES = ("COMPLETED",)
- template_fields: Sequence[str] = ('virtual_cluster_id', 'job_id')
+ template_fields: Sequence[str] = ("virtual_cluster_id", "job_id")
template_ext: Sequence[str] = ()
- ui_color = '#66c3ff'
+ ui_color = "#66c3ff"
def __init__(
self,
*,
virtual_cluster_id: str,
job_id: str,
- max_retries: Optional[int] = None,
- aws_conn_id: str = 'aws_default',
+ max_retries: int | None = None,
+ aws_conn_id: str = "aws_default",
poll_interval: int = 10,
**kwargs: Any,
) -> None:
@@ -167,11 +279,15 @@ def __init__(
self.poll_interval = poll_interval
self.max_retries = max_retries
- def poke(self, context: 'Context') -> bool:
- state = self.hook.poll_query_status(self.job_id, self.max_retries, self.poll_interval)
+ def poke(self, context: Context) -> bool:
+ state = self.hook.poll_query_status(
+ self.job_id,
+ max_polling_attempts=self.max_retries,
+ poll_interval=self.poll_interval,
+ )
if state in self.FAILURE_STATES:
- raise AirflowException('EMR Containers sensor failed')
+ raise AirflowException("EMR Containers sensor failed")
if state in self.INTERMEDIATE_STATES:
return False
@@ -204,23 +320,23 @@ class EmrJobFlowSensor(EmrBaseSensor):
job flow reaches any of these states
"""
- template_fields: Sequence[str] = ('job_flow_id', 'target_states', 'failed_states')
+ template_fields: Sequence[str] = ("job_flow_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
def __init__(
self,
*,
job_flow_id: str,
- target_states: Optional[Iterable[str]] = None,
- failed_states: Optional[Iterable[str]] = None,
+ target_states: Iterable[str] | None = None,
+ failed_states: Iterable[str] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
- self.target_states = target_states or ['TERMINATED']
- self.failed_states = failed_states or ['TERMINATED_WITH_ERRORS']
+ self.target_states = target_states or ["TERMINATED"]
+ self.failed_states = failed_states or ["TERMINATED_WITH_ERRORS"]
- def get_emr_response(self) -> Dict[str, Any]:
+ def get_emr_response(self) -> dict[str, Any]:
"""
Make an API call with boto3 and get cluster-level details.
@@ -228,35 +344,32 @@ def get_emr_response(self) -> Dict[str, Any]:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_cluster
:return: response
- :rtype: dict[str, Any]
"""
emr_client = self.get_hook().get_conn()
- self.log.info('Poking cluster %s', self.job_flow_id)
+ self.log.info("Poking cluster %s", self.job_flow_id)
return emr_client.describe_cluster(ClusterId=self.job_flow_id)
@staticmethod
- def state_from_response(response: Dict[str, Any]) -> str:
+ def state_from_response(response: dict[str, Any]) -> str:
"""
Get state from response dictionary.
:param response: response from AWS API
:return: current state of the cluster
- :rtype: str
"""
- return response['Cluster']['Status']['State']
+ return response["Cluster"]["Status"]["State"]
@staticmethod
- def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
+ def failure_message_from_response(response: dict[str, Any]) -> str | None:
"""
Get failure message from response dictionary.
:param response: response from AWS API
:return: failure message
- :rtype: Optional[str]
"""
- cluster_status = response['Cluster']['Status']
- state_change_reason = cluster_status.get('StateChangeReason')
+ cluster_status = response["Cluster"]["Status"]
+ state_change_reason = cluster_status.get("StateChangeReason")
if state_change_reason:
return (
f"for code: {state_change_reason.get('Code', 'No code')} "
@@ -265,6 +378,7 @@ def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
return None
+@poke_mode_only
class EmrStepSensor(EmrBaseSensor):
"""
Asks for the state of the step until it reaches any of the target states.
@@ -284,7 +398,7 @@ class EmrStepSensor(EmrBaseSensor):
step reaches any of these states
"""
- template_fields: Sequence[str] = ('job_flow_id', 'step_id', 'target_states', 'failed_states')
+ template_fields: Sequence[str] = ("job_flow_id", "step_id", "target_states", "failed_states")
template_ext: Sequence[str] = ()
def __init__(
@@ -292,17 +406,17 @@ def __init__(
*,
job_flow_id: str,
step_id: str,
- target_states: Optional[Iterable[str]] = None,
- failed_states: Optional[Iterable[str]] = None,
+ target_states: Iterable[str] | None = None,
+ failed_states: Iterable[str] | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.job_flow_id = job_flow_id
self.step_id = step_id
- self.target_states = target_states or ['COMPLETED']
- self.failed_states = failed_states or ['CANCELLED', 'FAILED', 'INTERRUPTED']
+ self.target_states = target_states or ["COMPLETED"]
+ self.failed_states = failed_states or ["CANCELLED", "FAILED", "INTERRUPTED"]
- def get_emr_response(self) -> Dict[str, Any]:
+ def get_emr_response(self) -> dict[str, Any]:
"""
Make an API call with boto3 and get details about the cluster step.
@@ -310,34 +424,31 @@ def get_emr_response(self) -> Dict[str, Any]:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr.html#EMR.Client.describe_step
:return: response
- :rtype: dict[str, Any]
"""
emr_client = self.get_hook().get_conn()
- self.log.info('Poking step %s on cluster %s', self.step_id, self.job_flow_id)
+ self.log.info("Poking step %s on cluster %s", self.step_id, self.job_flow_id)
return emr_client.describe_step(ClusterId=self.job_flow_id, StepId=self.step_id)
@staticmethod
- def state_from_response(response: Dict[str, Any]) -> str:
+ def state_from_response(response: dict[str, Any]) -> str:
"""
Get state from response dictionary.
:param response: response from AWS API
:return: execution state of the cluster step
- :rtype: str
"""
- return response['Step']['Status']['State']
+ return response["Step"]["Status"]["State"]
@staticmethod
- def failure_message_from_response(response: Dict[str, Any]) -> Optional[str]:
+ def failure_message_from_response(response: dict[str, Any]) -> str | None:
"""
Get failure message from response dictionary.
:param response: response from AWS API
:return: failure message
- :rtype: Optional[str]
"""
- fail_details = response['Step']['Status'].get('FailureDetails')
+ fail_details = response["Step"]["Status"].get("FailureDetails")
if fail_details:
return (
f"for reason {fail_details.get('Reason')} "
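Reviewer note: the `EmrBaseSensor.poke` flow in this file is a template method. The base class owns the polling decision, and each subclass supplies only `get_emr_response`, `state_from_response`, and `failure_message_from_response`. The sketch below reproduces that flow without Airflow or boto3 so it can be run in isolation. `SketchBaseSensor`, `SketchStepSensor`, and `make_response` are hypothetical names invented for illustration, and `RuntimeError` stands in for `AirflowException`.

```python
from __future__ import annotations

from typing import Any, Iterable


class SketchBaseSensor:
    """Simplified stand-in for EmrBaseSensor (assumption: no Airflow, no AWS calls)."""

    target_states: Iterable[str] = ()
    failed_states: Iterable[str] = ()

    def poke(self) -> bool:
        response = self.get_emr_response()
        # A non-200 response is treated as "not ready yet", not as a failure.
        if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
            return False

        state = self.state_from_response(response)
        if state in self.target_states:
            return True
        if state in self.failed_states:
            final_message = "EMR job failed"
            failure_message = self.failure_message_from_response(response)
            if failure_message:
                final_message += " " + failure_message
            # The real sensor raises AirflowException here.
            raise RuntimeError(final_message)
        return False

    def get_emr_response(self) -> dict[str, Any]:
        raise NotImplementedError("Please implement get_emr_response() in subclass")

    @staticmethod
    def state_from_response(response: dict[str, Any]) -> str:
        raise NotImplementedError("Please implement state_from_response() in subclass")

    @staticmethod
    def failure_message_from_response(response: dict[str, Any]) -> str | None:
        raise NotImplementedError("Please implement failure_message_from_response() in subclass")


class SketchStepSensor(SketchBaseSensor):
    """Hypothetical subclass that returns a canned response instead of calling describe_step."""

    target_states = ("COMPLETED",)
    failed_states = ("CANCELLED", "FAILED", "INTERRUPTED")

    def __init__(self, canned_response: dict[str, Any]) -> None:
        self.canned_response = canned_response

    def get_emr_response(self) -> dict[str, Any]:
        return self.canned_response

    @staticmethod
    def state_from_response(response: dict[str, Any]) -> str:
        return response["Step"]["Status"]["State"]

    @staticmethod
    def failure_message_from_response(response: dict[str, Any]) -> str | None:
        fail_details = response["Step"]["Status"].get("FailureDetails")
        if fail_details:
            return f"for reason {fail_details.get('Reason')}"
        return None


def make_response(
    state: str, http_status: int = 200, failure_details: dict[str, Any] | None = None
) -> dict[str, Any]:
    """Build a response shaped like the fields the sensor reads (illustrative only)."""
    status: dict[str, Any] = {"State": state}
    if failure_details is not None:
        status["FailureDetails"] = failure_details
    return {"ResponseMetadata": {"HTTPStatusCode": http_status}, "Step": {"Status": status}}
```

For example, `SketchStepSensor(make_response("RUNNING")).poke()` returns `False`, so the scheduler would poke again, while a `"FAILED"` state raises instead of returning.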
diff --git a/airflow/providers/amazon/aws/sensors/emr_base.py b/airflow/providers/amazon/aws/sensors/emr_base.py
deleted file mode 100644
index 89991d7fa04aa..0000000000000
--- a/airflow/providers/amazon/aws/sensors/emr_base.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/emr_containers.py b/airflow/providers/amazon/aws/sensors/emr_containers.py
deleted file mode 100644
index 6cfa9adb7d76d..0000000000000
--- a/airflow/providers/amazon/aws/sensors/emr_containers.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr import EmrBaseSensor, EmrContainerSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class EMRContainerSensor(EmrContainerSensor):
- """
- This class is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`.
- """
-
- def __init__(self, **kwargs):
- warnings.warn(
- """This class is deprecated.
- Please use `airflow.providers.amazon.aws.sensors.emr.EmrContainerSensor`.""",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(**kwargs)
diff --git a/airflow/providers/amazon/aws/sensors/emr_job_flow.py b/airflow/providers/amazon/aws/sensors/emr_job_flow.py
deleted file mode 100644
index 31b6dabb4a379..0000000000000
--- a/airflow/providers/amazon/aws/sensors/emr_job_flow.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/emr_step.py b/airflow/providers/amazon/aws/sensors/emr_step.py
deleted file mode 100644
index aca71619089e7..0000000000000
--- a/airflow/providers/amazon/aws/sensors/emr_step.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.emr`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.emr import EmrJobFlowSensor, EmrStepSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.emr`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/glacier.py b/airflow/providers/amazon/aws/sensors/glacier.py
index e92f5a4326b3a..857e578327a9e 100644
--- a/airflow/providers/amazon/aws/sensors/glacier.py
+++ b/airflow/providers/amazon/aws/sensors/glacier.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from enum import Enum
from typing import TYPE_CHECKING, Any, Sequence
@@ -38,7 +40,7 @@ class GlacierJobOperationSensor(BaseSensorOperator):
Glacier sensor for checking job state. This sensor runs only in reschedule mode.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:GlacierJobOperationSensor`
:param aws_conn_id: The reference to the AWS connection details
@@ -65,7 +67,7 @@ class GlacierJobOperationSensor(BaseSensorOperator):
def __init__(
self,
*,
- aws_conn_id: str = 'aws_default',
+ aws_conn_id: str = "aws_default",
vault_name: str,
job_id: str,
poke_interval: int = 60 * 20,
@@ -79,7 +81,7 @@ def __init__(
self.poke_interval = poke_interval
self.mode = mode
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
hook = GlacierHook(aws_conn_id=self.aws_conn_id)
response = hook.describe_job(vault_name=self.vault_name, job_id=self.job_id)
diff --git a/airflow/providers/amazon/aws/sensors/glue.py b/airflow/providers/amazon/aws/sensors/glue.py
index 525e7b8ee6234..87e8f2c249d63 100644
--- a/airflow/providers/amazon/aws/sensors/glue.py
+++ b/airflow/providers/amazon/aws/sensors/glue.py
@@ -15,7 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import warnings
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
@@ -37,43 +38,51 @@ class GlueJobSensor(BaseSensorOperator):
:param job_name: The AWS Glue Job unique name
:param run_id: The AWS Glue current running job identifier
+ :param verbose: If True, Glue Job Run logs are also shown in the Airflow Task Logs. (default: False)
"""
- template_fields: Sequence[str] = ('job_name', 'run_id')
+ template_fields: Sequence[str] = ("job_name", "run_id")
- def __init__(self, *, job_name: str, run_id: str, aws_conn_id: str = 'aws_default', **kwargs):
+ def __init__(
+ self,
+ *,
+ job_name: str,
+ run_id: str,
+ verbose: bool = False,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
super().__init__(**kwargs)
self.job_name = job_name
self.run_id = run_id
+ self.verbose = verbose
self.aws_conn_id = aws_conn_id
- self.success_states = ['SUCCEEDED']
- self.errored_states = ['FAILED', 'STOPPED', 'TIMEOUT']
+ self.success_states: list[str] = ["SUCCEEDED"]
+ self.errored_states: list[str] = ["FAILED", "STOPPED", "TIMEOUT"]
+ self.next_log_token: str | None = None
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
self.log.info("Poking for job run status :for Glue Job %s and ID %s", self.job_name, self.run_id)
job_state = hook.get_job_state(job_name=self.job_name, run_id=self.run_id)
- if job_state in self.success_states:
- self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state)
- return True
- elif job_state in self.errored_states:
- job_error_message = f"Exiting Job {self.run_id} Run State: {job_state}"
- raise AirflowException(job_error_message)
- else:
- return False
-
-
-class AwsGlueJobSensor(GlueJobSensor):
- """
- This sensor is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.glue.GlueJobSensor`.
- """
+ job_failed = False
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This sensor is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.sensors.glue.GlueJobSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
+ try:
+ if job_state in self.success_states:
+ self.log.info("Exiting Job %s Run State: %s", self.run_id, job_state)
+ return True
+ elif job_state in self.errored_states:
+ job_failed = True
+ job_error_message = f"Exiting Job {self.run_id} Run State: {job_state}"
+ self.log.info(job_error_message)
+ raise AirflowException(job_error_message)
+ else:
+ return False
+ finally:
+ if self.verbose:
+ self.next_log_token = hook.print_job_logs(
+ job_name=self.job_name,
+ run_id=self.run_id,
+ job_failed=job_failed,
+ next_token=self.next_log_token,
+ )
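Reviewer note: wrapping the state check in `try`/`finally` means the log-fetching step runs on every poke, whether `poke` returns `True`, returns `False`, or raises. The continuation token is therefore carried from one poke to the next. A minimal standalone sketch of that control flow follows. `FakeHook` and `SketchSensor` are hypothetical stand-ins for `GlueJobHook` and `GlueJobSensor`, the token values are made up, and `RuntimeError` replaces `AirflowException`.

```python
from __future__ import annotations


class FakeHook:
    """Hypothetical stand-in for GlueJobHook; records calls instead of reading logs."""

    def __init__(self) -> None:
        self.calls: list[tuple[bool, str | None]] = []

    def print_job_logs(self, job_failed: bool, next_token: str | None) -> str:
        self.calls.append((job_failed, next_token))
        # Made-up token; the real hook returns whatever the log API hands back.
        return f"token-{len(self.calls)}"


class SketchSensor:
    success_states = ["SUCCEEDED"]
    errored_states = ["FAILED", "STOPPED", "TIMEOUT"]

    def __init__(self, hook: FakeHook, verbose: bool = False) -> None:
        self.hook = hook
        self.verbose = verbose
        self.next_log_token: str | None = None

    def poke(self, job_state: str) -> bool:
        job_failed = False
        try:
            if job_state in self.success_states:
                return True
            elif job_state in self.errored_states:
                job_failed = True
                raise RuntimeError(f"Run State: {job_state}")
            else:
                return False
        finally:
            # Executes after a return and after a raise alike.
            if self.verbose:
                self.next_log_token = self.hook.print_job_logs(
                    job_failed=job_failed,
                    next_token=self.next_log_token,
                )
```

Setting `job_failed` before the `raise` matters: the `finally` block reads it, so the hook is told about the failure on the same poke that raises.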
diff --git a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
index c49277f34342f..21bf8cb772447 100644
--- a/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
+++ b/airflow/providers/amazon/aws/sensors/glue_catalog_partition.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import warnings
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.providers.amazon.aws.hooks.glue_catalog import GlueCatalogHook
from airflow.sensors.base import BaseSensorOperator
@@ -47,20 +48,20 @@ class GlueCatalogPartitionSensor(BaseSensorOperator):
"""
template_fields: Sequence[str] = (
- 'database_name',
- 'table_name',
- 'expression',
+ "database_name",
+ "table_name",
+ "expression",
)
- ui_color = '#C5CAE9'
+ ui_color = "#C5CAE9"
def __init__(
self,
*,
table_name: str,
expression: str = "ds='{{ ds }}'",
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
- database_name: str = 'default',
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
+ database_name: str = "default",
poke_interval: int = 60 * 3,
**kwargs,
):
@@ -70,14 +71,14 @@ def __init__(
self.table_name = table_name
self.expression = expression
self.database_name = database_name
- self.hook: Optional[GlueCatalogHook] = None
+ self.hook: GlueCatalogHook | None = None
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
"""Checks for existence of the partition in the AWS Glue Catalog table"""
- if '.' in self.table_name:
- self.database_name, self.table_name = self.table_name.split('.')
+ if "." in self.table_name:
+ self.database_name, self.table_name = self.table_name.split(".")
self.log.info(
- 'Poking for table %s. %s, expression %s', self.database_name, self.table_name, self.expression
+ "Poking for table %s. %s, expression %s", self.database_name, self.table_name, self.expression
)
return self.get_hook().check_for_partition(self.database_name, self.table_name, self.expression)
@@ -89,19 +90,3 @@ def get_hook(self) -> GlueCatalogHook:
self.hook = GlueCatalogHook(aws_conn_id=self.aws_conn_id, region_name=self.region_name)
return self.hook
-
-
-class AwsGlueCatalogPartitionSensor(GlueCatalogPartitionSensor):
- """
- This sensor is deprecated. Please use
- :class:`airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This sensor is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.sensors.glue_catalog_partition.GlueCatalogPartitionSensor`.", # noqa: 501
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/sensors/glue_crawler.py b/airflow/providers/amazon/aws/sensors/glue_crawler.py
index 52944ce2eb30e..3032c603c625b 100644
--- a/airflow/providers/amazon/aws/sensors/glue_crawler.py
+++ b/airflow/providers/amazon/aws/sensors/glue_crawler.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import warnings
-from typing import TYPE_CHECKING, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.glue_crawler import GlueCrawlerHook
@@ -39,21 +40,23 @@ class GlueCrawlerSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
"""
- def __init__(self, *, crawler_name: str, aws_conn_id: str = 'aws_default', **kwargs) -> None:
+ template_fields: Sequence[str] = ("crawler_name",)
+
+ def __init__(self, *, crawler_name: str, aws_conn_id: str = "aws_default", **kwargs) -> None:
super().__init__(**kwargs)
self.crawler_name = crawler_name
self.aws_conn_id = aws_conn_id
- self.success_statuses = 'SUCCEEDED'
- self.errored_statuses = ('FAILED', 'CANCELLED')
- self.hook: Optional[GlueCrawlerHook] = None
+ self.success_statuses = "SUCCEEDED"
+ self.errored_statuses = ("FAILED", "CANCELLED")
+ self.hook: GlueCrawlerHook | None = None
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
hook = self.get_hook()
self.log.info("Poking for AWS Glue crawler: %s", self.crawler_name)
- crawler_state = hook.get_crawler(self.crawler_name)['State']
- if crawler_state == 'READY':
+ crawler_state = hook.get_crawler(self.crawler_name)["State"]
+ if crawler_state == "READY":
self.log.info("State: %s", crawler_state)
- crawler_status = hook.get_crawler(self.crawler_name)['LastCrawl']['Status']
+ crawler_status = hook.get_crawler(self.crawler_name)["LastCrawl"]["Status"]
if crawler_status == self.success_statuses:
self.log.info("Status: %s", crawler_status)
return True
@@ -69,19 +72,3 @@ def get_hook(self) -> GlueCrawlerHook:
self.hook = GlueCrawlerHook(aws_conn_id=self.aws_conn_id)
return self.hook
-
-
-class AwsGlueCrawlerSensor(GlueCrawlerSensor):
- """
- This sensor is deprecated. Please use
- :class:`airflow.providers.amazon.aws.sensors.glue_crawler.GlueCrawlerSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This sensor is deprecated. "
- "Please use :class:`airflow.providers.amazon.aws.sensors.glue_crawler.GlueCrawlerSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
diff --git a/airflow/providers/amazon/aws/sensors/quicksight.py b/airflow/providers/amazon/aws/sensors/quicksight.py
index da94980e80650..09cc92cf96dd5 100644
--- a/airflow/providers/amazon/aws/sensors/quicksight.py
+++ b/airflow/providers/amazon/aws/sensors/quicksight.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-import sys
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Sequence
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.quicksight import QuickSightHook
from airflow.providers.amazon.aws.hooks.sts import StsHook
@@ -27,11 +28,6 @@
if TYPE_CHECKING:
from airflow.utils.context import Context
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
class QuickSightSensor(BaseSensorOperator):
"""
@@ -50,6 +46,8 @@ class QuickSightSensor(BaseSensorOperator):
maintained on each worker node).
"""
+ template_fields: Sequence[str] = ("data_set_id", "ingestion_id", "aws_conn_id")
+
def __init__(
self,
*,
@@ -64,16 +62,15 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.success_status = "COMPLETED"
self.errored_statuses = ("FAILED", "CANCELLED")
- self.quicksight_hook: Optional[QuickSightHook] = None
- self.sts_hook: Optional[StsHook] = None
+ self.quicksight_hook: QuickSightHook | None = None
+ self.sts_hook: StsHook | None = None
- def poke(self, context: "Context"):
+ def poke(self, context: Context) -> bool:
"""
Pokes until the QuickSight Ingestion has successfully finished.
:param context: The task context during execution.
:return: True if it COMPLETED and False if not.
- :rtype: bool
"""
quicksight_hook = self.get_quicksight_hook
sts_hook = self.get_sts_hook
diff --git a/airflow/providers/amazon/aws/sensors/rds.py b/airflow/providers/amazon/aws/sensors/rds.py
index 3c24c82fbf940..731c8b5def6fb 100644
--- a/airflow/providers/amazon/aws/sensors/rds.py
+++ b/airflow/providers/amazon/aws/sensors/rds.py
@@ -14,12 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, List, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
-from botocore.exceptions import ClientError
-
-from airflow import AirflowException
+from airflow.exceptions import AirflowNotFoundException
from airflow.providers.amazon.aws.hooks.rds import RdsHook
from airflow.providers.amazon.aws.utils.rds import RdsDbType
from airflow.sensors.base import BaseSensorOperator
@@ -34,43 +33,19 @@ class RdsBaseSensor(BaseSensorOperator):
ui_color = "#ddbb77"
ui_fgcolor = "#ffffff"
- def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: Optional[dict] = None, **kwargs):
+ def __init__(self, *args, aws_conn_id: str = "aws_conn_id", hook_params: dict | None = None, **kwargs):
hook_params = hook_params or {}
self.hook = RdsHook(aws_conn_id=aws_conn_id, **hook_params)
- self.target_statuses: List[str] = []
+ self.target_statuses: list[str] = []
super().__init__(*args, **kwargs)
- def _describe_item(self, item_type: str, item_name: str) -> list:
-
- if item_type == 'instance_snapshot':
- db_snaps = self.hook.conn.describe_db_snapshots(DBSnapshotIdentifier=item_name)
- return db_snaps['DBSnapshots']
- elif item_type == 'cluster_snapshot':
- cl_snaps = self.hook.conn.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=item_name)
- return cl_snaps['DBClusterSnapshots']
- elif item_type == 'export_task':
- exports = self.hook.conn.describe_export_tasks(ExportTaskIdentifier=item_name)
- return exports['ExportTasks']
- else:
- raise AirflowException(f"Method for {item_type} is not implemented")
-
- def _check_item(self, item_type: str, item_name: str) -> bool:
- """Get certain item from `_describe_item()` and check its status"""
-
- try:
- items = self._describe_item(item_type, item_name)
- except ClientError:
- return False
- else:
- return bool(items) and any(map(lambda s: items[0]['Status'].lower() == s, self.target_statuses))
-
class RdsSnapshotExistenceSensor(RdsBaseSensor):
"""
Waits for RDS snapshot with a specific status.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:RdsSnapshotExistenceSensor`
:param db_type: Type of the DB - either "instance" or "cluster"
@@ -79,8 +54,8 @@ class RdsSnapshotExistenceSensor(RdsBaseSensor):
"""
template_fields: Sequence[str] = (
- 'db_snapshot_identifier',
- 'target_statuses',
+ "db_snapshot_identifier",
+ "target_statuses",
)
def __init__(
@@ -88,23 +63,27 @@ def __init__(
*,
db_type: str,
db_snapshot_identifier: str,
- target_statuses: Optional[List[str]] = None,
+ target_statuses: list[str] | None = None,
aws_conn_id: str = "aws_conn_id",
**kwargs,
):
super().__init__(aws_conn_id=aws_conn_id, **kwargs)
self.db_type = RdsDbType(db_type)
self.db_snapshot_identifier = db_snapshot_identifier
- self.target_statuses = target_statuses or ['available']
+ self.target_statuses = target_statuses or ["available"]
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
self.log.info(
- 'Poking for statuses : %s\nfor snapshot %s', self.target_statuses, self.db_snapshot_identifier
+ "Poking for statuses : %s\nfor snapshot %s", self.target_statuses, self.db_snapshot_identifier
)
- if self.db_type.value == "instance":
- return self._check_item(item_type='instance_snapshot', item_name=self.db_snapshot_identifier)
- else:
- return self._check_item(item_type='cluster_snapshot', item_name=self.db_snapshot_identifier)
+ try:
+ if self.db_type.value == "instance":
+ state = self.hook.get_db_snapshot_state(self.db_snapshot_identifier)
+ else:
+ state = self.hook.get_db_cluster_snapshot_state(self.db_snapshot_identifier)
+ except AirflowNotFoundException:
+ return False
+ return state in self.target_statuses
class RdsExportTaskExistenceSensor(RdsBaseSensor):
@@ -112,7 +91,7 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
Waits for RDS export task with a specific status.
.. seealso::
- For more information on how to use this operator, take a look at the guide:
+ For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:RdsExportTaskExistenceSensor`
:param export_task_identifier: A unique identifier for the snapshot export task.
@@ -120,15 +99,15 @@ class RdsExportTaskExistenceSensor(RdsBaseSensor):
"""
template_fields: Sequence[str] = (
- 'export_task_identifier',
- 'target_statuses',
+ "export_task_identifier",
+ "target_statuses",
)
def __init__(
self,
*,
export_task_identifier: str,
- target_statuses: Optional[List[str]] = None,
+ target_statuses: list[str] | None = None,
aws_conn_id: str = "aws_default",
**kwargs,
):
@@ -136,21 +115,74 @@ def __init__(
self.export_task_identifier = export_task_identifier
self.target_statuses = target_statuses or [
- 'starting',
- 'in_progress',
- 'complete',
- 'canceling',
- 'canceled',
+ "starting",
+ "in_progress",
+ "complete",
+ "canceling",
+ "canceled",
]
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
+ self.log.info(
+ "Poking for statuses : %s\nfor export task %s", self.target_statuses, self.export_task_identifier
+ )
+ try:
+ state = self.hook.get_export_task_state(self.export_task_identifier)
+ except AirflowNotFoundException:
+ return False
+ return state in self.target_statuses
+
+
+class RdsDbSensor(RdsBaseSensor):
+ """
+ Waits for an RDS instance or cluster to enter one of a number of states.
+
+ .. seealso::
+ For more information on how to use this sensor, take a look at the guide:
+ :ref:`howto/sensor:RdsDbSensor`
+
+ :param db_type: Type of the DB - either "instance" or "cluster" (default: 'instance')
+ :param db_identifier: The AWS identifier for the DB
+ :param target_statuses: Target status of DB
+ """
+
+ template_fields: Sequence[str] = (
+ "db_identifier",
+ "db_type",
+ "target_statuses",
+ )
+
+ def __init__(
+ self,
+ *,
+ db_identifier: str,
+ db_type: RdsDbType | str = RdsDbType.INSTANCE,
+ target_statuses: list[str] | None = None,
+ aws_conn_id: str = "aws_default",
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_identifier = db_identifier
+ self.target_statuses = target_statuses or ["available"]
+ self.db_type = db_type
+
+ def poke(self, context: Context):
+ db_type = RdsDbType(self.db_type)
self.log.info(
- 'Poking for statuses : %s\nfor export task %s', self.target_statuses, self.export_task_identifier
+ "Poking for statuses : %s\nfor db instance %s", self.target_statuses, self.db_identifier
)
- return self._check_item(item_type='export_task', item_name=self.export_task_identifier)
+ try:
+ if db_type == RdsDbType.INSTANCE:
+ state = self.hook.get_db_instance_state(self.db_identifier)
+ else:
+ state = self.hook.get_db_cluster_state(self.db_identifier)
+ except AirflowNotFoundException:
+ return False
+ return state in self.target_statuses
__all__ = [
"RdsExportTaskExistenceSensor",
+ "RdsDbSensor",
"RdsSnapshotExistenceSensor",
]
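
Note on the rds.py change above: the three sensors now share one poke pattern. Each asks the hook for a state, treats "not found" as "not ready yet", and otherwise tests membership in ``target_statuses``. The following is a minimal sketch of that pattern, not the provider code. ``NotFoundError`` and ``FakeRdsHook`` are hypothetical stand-ins for ``AirflowNotFoundException`` and ``RdsHook``, so the example runs without Airflow or AWS.

```python
from __future__ import annotations


class NotFoundError(Exception):
    """Hypothetical stand-in for AirflowNotFoundException."""


class FakeRdsHook:
    """Hypothetical in-memory hook; the real RdsHook queries AWS."""

    def __init__(self, states: dict[str, str]) -> None:
        self._states = states

    def get_db_instance_state(self, db_identifier: str) -> str:
        # Raise instead of returning None, mirroring the try/except in poke().
        try:
            return self._states[db_identifier]
        except KeyError:
            raise NotFoundError(db_identifier) from None


def poke(hook: FakeRdsHook, db_identifier: str, target_statuses: list[str]) -> bool:
    """Return True only when the resource exists and is in a target status."""
    try:
        state = hook.get_db_instance_state(db_identifier)
    except NotFoundError:
        # A missing resource is not an error for a sensor: keep waiting.
        return False
    return state in target_statuses
```

With a hook built from ``{"db-1": "available", "db-2": "creating"}``, ``poke(hook, "db-1", ["available"])`` succeeds, while ``"db-2"`` and an unknown identifier both return ``False`` and would be polled again.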
diff --git a/airflow/providers/amazon/aws/sensors/redshift.py b/airflow/providers/amazon/aws/sensors/redshift.py
deleted file mode 100644
index 6a73e7ddba962..0000000000000
--- a/airflow/providers/amazon/aws/sensors/redshift.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import warnings
-
-from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor
-
-AwsRedshiftClusterSensor = RedshiftClusterSensor
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.redshift_cluster`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-__all__ = ["AwsRedshiftClusterSensor", "RedshiftClusterSensor"]
diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
index ae772e95ffbf0..76f4f90111755 100644
--- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.sensors.base import BaseSensorOperator
@@ -35,24 +37,24 @@ class RedshiftClusterSensor(BaseSensorOperator):
:param target_status: The cluster status desired.
"""
- template_fields: Sequence[str] = ('cluster_identifier', 'target_status')
+ template_fields: Sequence[str] = ("cluster_identifier", "target_status")
def __init__(
self,
*,
cluster_identifier: str,
- target_status: str = 'available',
- aws_conn_id: str = 'aws_default',
+ target_status: str = "available",
+ aws_conn_id: str = "aws_default",
**kwargs,
):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.target_status = target_status
self.aws_conn_id = aws_conn_id
- self.hook: Optional[RedshiftHook] = None
+ self.hook: RedshiftHook | None = None
- def poke(self, context: 'Context'):
- self.log.info('Poking for status : %s\nfor cluster %s', self.target_status, self.cluster_identifier)
+ def poke(self, context: Context):
+ self.log.info("Poking for status : %s\nfor cluster %s", self.target_status, self.cluster_identifier)
return self.get_hook().cluster_status(self.cluster_identifier) == self.target_status
def get_hook(self) -> RedshiftHook:
diff --git a/airflow/providers/amazon/aws/sensors/s3.py b/airflow/providers/amazon/aws/sensors/s3.py
index 182b05864cf1d..c60e46841ea98 100644
--- a/airflow/providers/amazon/aws/sensors/s3.py
+++ b/airflow/providers/amazon/aws/sensors/s3.py
@@ -15,24 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+import fnmatch
import os
import re
-import sys
-import warnings
from datetime import datetime
-from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Set, Union
+from typing import TYPE_CHECKING, Callable, Sequence
if TYPE_CHECKING:
from airflow.utils.context import Context
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
@@ -77,17 +71,17 @@ def check_fn(files: List) -> bool:
CA cert bundle than the one used by botocore.
"""
- template_fields: Sequence[str] = ('bucket_key', 'bucket_name')
+ template_fields: Sequence[str] = ("bucket_key", "bucket_name")
def __init__(
self,
*,
- bucket_key: Union[str, List[str]],
- bucket_name: Optional[str] = None,
+ bucket_key: str | list[str],
+ bucket_name: str | None = None,
wildcard_match: bool = False,
- check_fn: Optional[Callable[..., bool]] = None,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
+ check_fn: Callable[..., bool] | None = None,
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -97,11 +91,11 @@ def __init__(
self.check_fn = check_fn
self.aws_conn_id = aws_conn_id
self.verify = verify
- self.hook: Optional[S3Hook] = None
+ self.hook: S3Hook | None = None
def _check_key(self, key):
- bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, 'bucket_name', 'bucket_key')
- self.log.info('Poking for key : s3://%s/%s', bucket_name, key)
+ bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
+ self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
"""
Set variable `files` which contains a list of dict which contains only the size
@@ -111,25 +105,26 @@ def _check_key(self, key):
}]
"""
if self.wildcard_match:
- prefix = re.split(r'[\[\*\?]', key, 1)[0]
- files = self.get_hook().get_file_metadata(prefix, bucket_name)
- if len(files) == 0:
+ prefix = re.split(r"[\[\*\?]", key, 1)[0]
+ keys = self.get_hook().get_file_metadata(prefix, bucket_name)
+ key_matches = [k for k in keys if fnmatch.fnmatch(k["Key"], key)]
+ if len(key_matches) == 0:
return False
# Reduce the set of metadata to size only
- files = list(map(lambda f: {'Size': f['Size']}, files))
+ files = list(map(lambda f: {"Size": f["Size"]}, key_matches))
else:
obj = self.get_hook().head_object(key, bucket_name)
if obj is None:
return False
- files = [{'Size': obj['ContentLength']}]
+ files = [{"Size": obj["ContentLength"]}]
if self.check_fn is not None:
return self.check_fn(files)
return True
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
return all(self._check_key(key) for key in self.bucket_key)
def get_hook(self) -> S3Hook:
@@ -141,40 +136,6 @@ def get_hook(self) -> S3Hook:
return self.hook
-class S3KeySizeSensor(S3KeySensor):
- """
- This class is deprecated.
- Please use :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeySensor`.
- """
-
- def __init__(
- self,
- *,
- check_fn: Optional[Callable[..., bool]] = None,
- **kwargs,
- ):
- warnings.warn(
- """
- S3KeySizeSensor is deprecated.
- Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
-
- super().__init__(
- check_fn=check_fn if check_fn is not None else S3KeySizeSensor.default_check_fn, **kwargs
- )
-
- @staticmethod
- def default_check_fn(data: List) -> bool:
- """Default function for checking that S3 Objects have size more than 0
-
- :param data: List of the objects in S3 bucket.
- """
- return all(f.get('Size', 0) > 0 for f in data)
-
-
@poke_mode_only
class S3KeysUnchangedSensor(BaseSensorOperator):
"""
@@ -213,18 +174,18 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
when this happens. If false an error will be raised.
"""
- template_fields: Sequence[str] = ('bucket_name', 'prefix')
+ template_fields: Sequence[str] = ("bucket_name", "prefix")
def __init__(
self,
*,
bucket_name: str,
prefix: str,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
+ aws_conn_id: str = "aws_default",
+ verify: bool | str | None = None,
inactivity_period: float = 60 * 60,
min_objects: int = 1,
- previous_objects: Optional[Set[str]] = None,
+ previous_objects: set[str] | None = None,
allow_delete: bool = True,
**kwargs,
) -> None:
@@ -242,14 +203,14 @@ def __init__(
self.allow_delete = allow_delete
self.aws_conn_id = aws_conn_id
self.verify = verify
- self.last_activity_time: Optional[datetime] = None
+ self.last_activity_time: datetime | None = None
@cached_property
def hook(self):
"""Returns S3Hook."""
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
- def is_keys_unchanged(self, current_objects: Set[str]) -> bool:
+ def is_keys_unchanged(self, current_objects: set[str]) -> bool:
"""
Checks whether new objects have been uploaded and the inactivity_period
has passed and updates the state of the sensor accordingly.
@@ -313,36 +274,5 @@ def is_keys_unchanged(self, current_objects: Set[str]) -> bool:
return False
return False
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
return self.is_keys_unchanged(set(self.hook.list_keys(self.bucket_name, prefix=self.prefix)))
-
-
-class S3PrefixSensor(S3KeySensor):
- """
- This class is deprecated.
- Please use :class:`~airflow.providers.amazon.aws.sensors.s3.S3KeySensor`.
- """
-
- template_fields: Sequence[str] = ('prefix', 'bucket_name')
-
- def __init__(
- self,
- *,
- prefix: Union[str, List[str]],
- delimiter: str = '/',
- **kwargs,
- ):
- warnings.warn(
- """
- S3PrefixSensor is deprecated.
- Please use `airflow.providers.amazon.aws.sensors.s3.S3KeySensor`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
-
- self.prefix = prefix
- prefixes = [self.prefix] if isinstance(self.prefix, str) else self.prefix
- keys = [pref if pref.endswith(delimiter) else pref + delimiter for pref in prefixes]
-
- super().__init__(bucket_key=keys, **kwargs)
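
Note on the ``wildcard_match`` change in s3.py above: the listing is still narrowed by the literal prefix before the first wildcard character, but the returned keys are now also filtered with ``fnmatch`` against the full pattern. The sketch below isolates those two steps. The function names ``listing_prefix`` and ``match_keys`` are illustrative only, and the object listing is hard-coded rather than fetched from S3.

```python
from __future__ import annotations

import fnmatch
import re


def listing_prefix(pattern: str) -> str:
    """Literal part of the pattern before the first '[', '*' or '?'."""
    return re.split(r"[\[\*\?]", pattern, maxsplit=1)[0]


def match_keys(pattern: str, listed: list[dict]) -> list[dict]:
    """Keep only objects whose key matches the pattern, reduced to their size."""
    return [{"Size": obj["Size"]} for obj in listed if fnmatch.fnmatch(obj["Key"], pattern)]


listed = [
    {"Key": "data/2022-01/part-1.csv", "Size": 10},
    {"Key": "data/2022-01/readme.txt", "Size": 3},
]
pattern = "data/2022-*/part-?.csv"
prefix = listing_prefix(pattern)
files = match_keys(pattern, listed)
```

Here both objects share the listing prefix ``data/2022-``, and only the first one survives the ``fnmatch`` filter, so only it would reach ``check_fn``. Note that ``fnmatch.fnmatch`` normalizes case according to the operating system; ``fnmatch.fnmatchcase`` is the always case-sensitive variant.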
diff --git a/airflow/providers/amazon/aws/sensors/s3_key.py b/airflow/providers/amazon/aws/sensors/s3_key.py
deleted file mode 100644
index deff11d00d1f2..0000000000000
--- a/airflow/providers/amazon/aws/sensors/s3_key.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.s3 import S3KeySensor, S3KeySizeSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py b/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py
deleted file mode 100644
index 792d29c46cfcd..0000000000000
--- a/airflow/providers/amazon/aws/sensors/s3_keys_unchanged.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.s3 import S3KeysUnchangedSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/s3_prefix.py b/airflow/providers/amazon/aws/sensors/s3_prefix.py
deleted file mode 100644
index 3990e8774a675..0000000000000
--- a/airflow/providers/amazon/aws/sensors/s3_prefix.py
+++ /dev/null
@@ -1,28 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.s3`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.s3 import S3PrefixSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker.py b/airflow/providers/amazon/aws/sensors/sagemaker.py
index 925ddaed17fd9..2d9c9aacf0f06 100644
--- a/airflow/providers/amazon/aws/sensors/sagemaker.py
+++ b/airflow/providers/amazon/aws/sensors/sagemaker.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import time
-from typing import TYPE_CHECKING, Optional, Sequence, Set
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sagemaker import LogState, SageMakerHook
@@ -34,12 +35,12 @@ class SageMakerBaseSensor(BaseSensorOperator):
Subclasses should also implement NON_TERMINAL_STATES and FAILED_STATE methods.
"""
- ui_color = '#ededed'
+ ui_color = "#ededed"
- def __init__(self, *, aws_conn_id: str = 'aws_default', **kwargs):
+ def __init__(self, *, aws_conn_id: str = "aws_default", **kwargs):
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
- self.hook: Optional[SageMakerHook] = None
+ self.hook: SageMakerHook | None = None
def get_hook(self) -> SageMakerHook:
"""Get SageMakerHook."""
@@ -48,39 +49,39 @@ def get_hook(self) -> SageMakerHook:
self.hook = SageMakerHook(aws_conn_id=self.aws_conn_id)
return self.hook
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
response = self.get_sagemaker_response()
- if response['ResponseMetadata']['HTTPStatusCode'] != 200:
- self.log.info('Bad HTTP response: %s', response)
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
+ self.log.info("Bad HTTP response: %s", response)
return False
state = self.state_from_response(response)
- self.log.info('Job currently %s', state)
+ self.log.info("Job currently %s", state)
if state in self.non_terminal_states():
return False
if state in self.failed_states():
failed_reason = self.get_failed_reason_from_response(response)
- raise AirflowException(f'Sagemaker job failed for the following reason: {failed_reason}')
+ raise AirflowException(f"Sagemaker job failed for the following reason: {failed_reason}")
return True
- def non_terminal_states(self) -> Set[str]:
+ def non_terminal_states(self) -> set[str]:
"""Placeholder for returning states which should not terminate."""
- raise NotImplementedError('Please implement non_terminal_states() in subclass')
+ raise NotImplementedError("Please implement non_terminal_states() in subclass")
- def failed_states(self) -> Set[str]:
+ def failed_states(self) -> set[str]:
"""Placeholder for returning states which are considered failed."""
- raise NotImplementedError('Please implement failed_states() in subclass')
+ raise NotImplementedError("Please implement failed_states() in subclass")
def get_sagemaker_response(self) -> dict:
"""Placeholder for checking status of a SageMaker task."""
- raise NotImplementedError('Please implement get_sagemaker_response() in subclass')
+ raise NotImplementedError("Please implement get_sagemaker_response() in subclass")
def get_failed_reason_from_response(self, response: dict) -> str:
"""Placeholder for extracting the reason for failure from an AWS response."""
- return 'Unknown'
+ return "Unknown"
def state_from_response(self, response: dict) -> str:
"""Placeholder for extracting the state from an AWS response."""
- raise NotImplementedError('Please implement state_from_response() in subclass')
+ raise NotImplementedError("Please implement state_from_response() in subclass")
class SageMakerEndpointSensor(SageMakerBaseSensor):
@@ -95,7 +96,7 @@ class SageMakerEndpointSensor(SageMakerBaseSensor):
:param endpoint_name: Name of the endpoint instance to watch.
"""
- template_fields: Sequence[str] = ('endpoint_name',)
+ template_fields: Sequence[str] = ("endpoint_name",)
template_ext: Sequence[str] = ()
def __init__(self, *, endpoint_name, **kwargs):
@@ -109,14 +110,14 @@ def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
- self.log.info('Poking Sagemaker Endpoint %s', self.endpoint_name)
+ self.log.info("Poking Sagemaker Endpoint %s", self.endpoint_name)
return self.get_hook().describe_endpoint(self.endpoint_name)
def get_failed_reason_from_response(self, response):
- return response['FailureReason']
+ return response["FailureReason"]
def state_from_response(self, response):
- return response['EndpointStatus']
+ return response["EndpointStatus"]
class SageMakerTransformSensor(SageMakerBaseSensor):
@@ -131,7 +132,7 @@ class SageMakerTransformSensor(SageMakerBaseSensor):
:param job_name: Name of the transform job to watch.
"""
- template_fields: Sequence[str] = ('job_name',)
+ template_fields: Sequence[str] = ("job_name",)
template_ext: Sequence[str] = ()
def __init__(self, *, job_name: str, **kwargs):
@@ -145,14 +146,14 @@ def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
- self.log.info('Poking Sagemaker Transform Job %s', self.job_name)
+ self.log.info("Poking Sagemaker Transform Job %s", self.job_name)
return self.get_hook().describe_transform_job(self.job_name)
def get_failed_reason_from_response(self, response):
- return response['FailureReason']
+ return response["FailureReason"]
def state_from_response(self, response):
- return response['TransformJobStatus']
+ return response["TransformJobStatus"]
class SageMakerTuningSensor(SageMakerBaseSensor):
@@ -167,7 +168,7 @@ class SageMakerTuningSensor(SageMakerBaseSensor):
:param job_name: Name of the tuning instance to watch.
"""
- template_fields: Sequence[str] = ('job_name',)
+ template_fields: Sequence[str] = ("job_name",)
template_ext: Sequence[str] = ()
def __init__(self, *, job_name: str, **kwargs):
@@ -181,14 +182,14 @@ def failed_states(self):
return SageMakerHook.failed_states
def get_sagemaker_response(self):
- self.log.info('Poking Sagemaker Tuning Job %s', self.job_name)
+ self.log.info("Poking Sagemaker Tuning Job %s", self.job_name)
return self.get_hook().describe_tuning_job(self.job_name)
def get_failed_reason_from_response(self, response):
- return response['FailureReason']
+ return response["FailureReason"]
def state_from_response(self, response):
- return response['HyperParameterTuningJobStatus']
+ return response["HyperParameterTuningJobStatus"]
class SageMakerTrainingSensor(SageMakerBaseSensor):
@@ -204,7 +205,7 @@ class SageMakerTrainingSensor(SageMakerBaseSensor):
:param print_log: Prints the cloudwatch log if True; Defaults to True.
"""
- template_fields: Sequence[str] = ('job_name',)
+ template_fields: Sequence[str] = ("job_name",)
template_ext: Sequence[str] = ()
def __init__(self, *, job_name, print_log=True, **kwargs):
@@ -213,8 +214,8 @@ def __init__(self, *, job_name, print_log=True, **kwargs):
self.print_log = print_log
self.positions = {}
self.stream_names = []
- self.instance_count: Optional[int] = None
- self.state: Optional[int] = None
+ self.instance_count: int | None = None
+ self.state: int | None = None
self.last_description = None
self.last_describe_job_call = None
self.log_resource_inited = False
@@ -222,8 +223,8 @@ def __init__(self, *, job_name, print_log=True, **kwargs):
def init_log_resource(self, hook: SageMakerHook) -> None:
"""Set tailing LogState for associated training job."""
description = hook.describe_training_job(self.job_name)
- self.instance_count = description['ResourceConfig']['InstanceCount']
- status = description['TrainingJobStatus']
+ self.instance_count = description["ResourceConfig"]["InstanceCount"]
+ status = description["TrainingJobStatus"]
job_already_completed = status not in self.non_terminal_states()
self.state = LogState.COMPLETE if job_already_completed else LogState.TAILING
self.last_description = description
@@ -258,13 +259,13 @@ def get_sagemaker_response(self):
status = self.state_from_response(self.last_description)
if (status not in self.non_terminal_states()) and (status not in self.failed_states()):
billable_time = (
- self.last_description['TrainingEndTime'] - self.last_description['TrainingStartTime']
- ) * self.last_description['ResourceConfig']['InstanceCount']
- self.log.info('Billable seconds: %s', (int(billable_time.total_seconds()) + 1))
+ self.last_description["TrainingEndTime"] - self.last_description["TrainingStartTime"]
+ ) * self.last_description["ResourceConfig"]["InstanceCount"]
+ self.log.info("Billable seconds: %s", (int(billable_time.total_seconds()) + 1))
return self.last_description
def get_failed_reason_from_response(self, response):
- return response['FailureReason']
+ return response["FailureReason"]
def state_from_response(self, response):
- return response['TrainingJobStatus']
+ return response["TrainingJobStatus"]
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_base.py b/airflow/providers/amazon/aws/sensors/sagemaker_base.py
deleted file mode 100644
index 102c410d279a4..0000000000000
--- a/airflow/providers/amazon/aws/sensors/sagemaker_base.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerBaseSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py b/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
deleted file mode 100644
index 00ed8442386b4..0000000000000
--- a/airflow/providers/amazon/aws/sensors/sagemaker_endpoint.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerEndpointSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_training.py b/airflow/providers/amazon/aws/sensors/sagemaker_training.py
deleted file mode 100644
index d1949964956e8..0000000000000
--- a/airflow/providers/amazon/aws/sensors/sagemaker_training.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTrainingSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py b/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
deleted file mode 100644
index 7a48f3e7de006..0000000000000
--- a/airflow/providers/amazon/aws/sensors/sagemaker_transform.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTransformSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py b/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
deleted file mode 100644
index d5f0d90555ffd..0000000000000
--- a/airflow/providers/amazon/aws/sensors/sagemaker_tuning.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.sagemaker`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.sagemaker import SageMakerTuningSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.sagemaker`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/sensors/sqs.py b/airflow/providers/amazon/aws/sensors/sqs.py
index cc026ec5e8306..8b09203d2e319 100644
--- a/airflow/providers/amazon/aws/sensors/sqs.py
+++ b/airflow/providers/amazon/aws/sensors/sqs.py
@@ -16,14 +16,16 @@
# specific language governing permissions and limitations
# under the License.
"""Reads and then deletes the message from SQS queue"""
+from __future__ import annotations
+
import json
-import warnings
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Collection, Sequence
from jsonpath_ng import parse
from typing_extensions import Literal
from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.sensors.base import BaseSensorOperator
@@ -34,9 +36,13 @@
class SqsSensor(BaseSensorOperator):
"""
Get messages from an Amazon SQS queue and then delete the messages from the queue.
- If deletion of messages fails an AirflowException is thrown. Otherwise, the messages
+ If deletion of messages fails, an AirflowException is thrown. Otherwise, the messages
are pushed through XCom with the key ``messages``.
+ By default, the sensor performs one and only one SQS call per poke, which limits the result to
+ a maximum of 10 messages. However, the total number of SQS API calls per poke can be controlled
+ by the ``num_batches`` parameter.
+
.. seealso::
For more information on how to use this sensor, take a look at the guide:
:ref:`howto/sensor:SqsSensor`
@@ -44,6 +50,7 @@ class SqsSensor(BaseSensorOperator):
:param aws_conn_id: AWS connection id
:param sqs_queue: The SQS queue url (templated)
:param max_messages: The maximum number of messages to retrieve for each poke (templated)
+ :param num_batches: The number of times the sensor will call the SQS API to receive messages (default: 1)
:param wait_time_seconds: The time in seconds to wait for receiving messages (default: 1 second)
:param visibility_timeout: Visibility timeout, a period of time during which
Amazon SQS prevents other consumers from receiving and processing the message.
@@ -64,17 +71,18 @@ class SqsSensor(BaseSensorOperator):
"""
- template_fields: Sequence[str] = ('sqs_queue', 'max_messages', 'message_filtering_config')
+ template_fields: Sequence[str] = ("sqs_queue", "max_messages", "message_filtering_config")
def __init__(
self,
*,
sqs_queue,
- aws_conn_id: str = 'aws_default',
+ aws_conn_id: str = "aws_default",
max_messages: int = 5,
+ num_batches: int = 1,
wait_time_seconds: int = 1,
- visibility_timeout: Optional[int] = None,
- message_filtering: Optional[Literal["literal", "jsonpath"]] = None,
+ visibility_timeout: int | None = None,
+ message_filtering: Literal["literal", "jsonpath"] | None = None,
message_filtering_match_values: Any = None,
message_filtering_config: Any = None,
delete_message_on_reception: bool = True,
@@ -84,6 +92,7 @@ def __init__(
self.sqs_queue = sqs_queue
self.aws_conn_id = aws_conn_id
self.max_messages = max_messages
+ self.num_batches = num_batches
self.wait_time_seconds = wait_time_seconds
self.visibility_timeout = visibility_timeout
@@ -96,39 +105,37 @@ def __init__(
message_filtering_match_values = set(message_filtering_match_values)
self.message_filtering_match_values = message_filtering_match_values
- if self.message_filtering == 'literal':
+ if self.message_filtering == "literal":
if self.message_filtering_match_values is None:
- raise TypeError('message_filtering_match_values must be specified for literal matching')
+ raise TypeError("message_filtering_match_values must be specified for literal matching")
self.message_filtering_config = message_filtering_config
- self.hook: Optional[SqsHook] = None
+ self.hook: SqsHook | None = None
- def poke(self, context: 'Context'):
+ def poll_sqs(self, sqs_conn: BaseAwsConnection) -> Collection:
"""
- Check for message on subscribed queue and write to xcom the message with key ``messages``
+ Poll SQS queue to retrieve messages.
- :param context: the context object
- :return: ``True`` if message is available or ``False``
+ :param sqs_conn: SQS connection
+ :return: A list of messages retrieved from SQS
"""
- sqs_conn = self.get_hook().get_conn()
-
- self.log.info('SqsSensor checking for message on queue: %s', self.sqs_queue)
+ self.log.info("SqsSensor checking for message on queue: %s", self.sqs_queue)
receive_message_kwargs = {
- 'QueueUrl': self.sqs_queue,
- 'MaxNumberOfMessages': self.max_messages,
- 'WaitTimeSeconds': self.wait_time_seconds,
+ "QueueUrl": self.sqs_queue,
+ "MaxNumberOfMessages": self.max_messages,
+ "WaitTimeSeconds": self.wait_time_seconds,
}
if self.visibility_timeout is not None:
- receive_message_kwargs['VisibilityTimeout'] = self.visibility_timeout
+ receive_message_kwargs["VisibilityTimeout"] = self.visibility_timeout
response = sqs_conn.receive_message(**receive_message_kwargs)
if "Messages" not in response:
- return False
+ return []
- messages = response['Messages']
+ messages = response["Messages"]
num_messages = len(messages)
self.log.info("Received %d messages", num_messages)
@@ -136,28 +143,47 @@ def poke(self, context: 'Context'):
messages = self.filter_messages(messages)
num_messages = len(messages)
self.log.info("There are %d messages left after filtering", num_messages)
+ return messages
- if not num_messages:
- return False
+ def poke(self, context: Context):
+ """
+ Check subscribed queue for messages and write them to xcom with the ``messages`` key.
- if not self.delete_message_on_reception:
- context['ti'].xcom_push(key='messages', value=messages)
- return True
+ :param context: the context object
+ :return: ``True`` if message is available or ``False``
+ """
+ sqs_conn = self.get_hook().get_conn()
- self.log.info("Deleting %d messages", num_messages)
+ message_batch: list[Any] = []
- entries = [
- {'Id': message['MessageId'], 'ReceiptHandle': message['ReceiptHandle']} for message in messages
- ]
- response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+ # perform multiple SQS calls to retrieve messages in series
+ for _ in range(self.num_batches):
+ messages = self.poll_sqs(sqs_conn=sqs_conn)
- if 'Successful' in response:
- context['ti'].xcom_push(key='messages', value=messages)
- return True
- else:
- raise AirflowException(
- 'Delete SQS Messages failed ' + str(response) + ' for messages ' + str(messages)
- )
+ if not len(messages):
+ continue
+
+ message_batch.extend(messages)
+
+ if self.delete_message_on_reception:
+
+ self.log.info("Deleting %d messages", len(messages))
+
+ entries = [
+ {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
+ for message in messages
+ ]
+ response = sqs_conn.delete_message_batch(QueueUrl=self.sqs_queue, Entries=entries)
+
+ if "Successful" not in response:
+ raise AirflowException(
+ "Delete SQS Messages failed " + str(response) + " for messages " + str(messages)
+ )
+ if not len(message_batch):
+ return False
+
+ context["ti"].xcom_push(key="messages", value=message_batch)
+ return True
def get_hook(self) -> SqsHook:
"""Create and return an SqsHook"""
@@ -168,17 +194,17 @@ def get_hook(self) -> SqsHook:
return self.hook
def filter_messages(self, messages):
- if self.message_filtering == 'literal':
+ if self.message_filtering == "literal":
return self.filter_messages_literal(messages)
- if self.message_filtering == 'jsonpath':
+ if self.message_filtering == "jsonpath":
return self.filter_messages_jsonpath(messages)
else:
- raise NotImplementedError('Override this method to define custom filters')
+ raise NotImplementedError("Override this method to define custom filters")
def filter_messages_literal(self, messages):
filtered_messages = []
for message in messages:
- if message['Body'] in self.message_filtering_match_values:
+ if message["Body"] in self.message_filtering_match_values:
filtered_messages.append(message)
return filtered_messages
@@ -186,7 +212,7 @@ def filter_messages_jsonpath(self, messages):
jsonpath_expr = parse(self.message_filtering_config)
filtered_messages = []
for message in messages:
- body = message['Body']
+ body = message["Body"]
# Body is a string, deserialize to an object and then parse
body = json.loads(body)
results = jsonpath_expr.find(body)
@@ -200,18 +226,3 @@ def filter_messages_jsonpath(self, messages):
filtered_messages.append(message)
break
return filtered_messages
-
-
-class SQSSensor(SqsSensor):
- """
- This sensor is deprecated.
- Please use :class:`airflow.providers.amazon.aws.sensors.sqs.SqsSensor`.
- """
-
- def __init__(self, *args, **kwargs):
- warnings.warn(
- "This class is deprecated. Please use `airflow.providers.amazon.aws.sensors.sqs.SqsSensor`.",
- DeprecationWarning,
- stacklevel=2,
- )
- super().__init__(*args, **kwargs)
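Reviewer note: the reworked ``poke`` in ``sqs.py`` polls the queue ``num_batches`` times, accumulates whatever each call returns, and deletes each batch right after receiving it. The following is a minimal standalone sketch of that loop, not the sensor itself. ``FakeSqsClient`` and ``poll_in_batches`` are hypothetical names invented for illustration; the fake only imitates the ``receive_message`` and ``delete_message_batch`` response shapes that the diff relies on, and message filtering is left out.

```python
from __future__ import annotations

from typing import Any


class FakeSqsClient:
    """Hypothetical in-memory stand-in for an SQS client (illustration only)."""

    def __init__(self, bodies: list[str]) -> None:
        self._queue = [
            {"MessageId": str(i), "ReceiptHandle": f"rh-{i}", "Body": body}
            for i, body in enumerate(bodies)
        ]
        self.deleted: list[str] = []

    def receive_message(self, QueueUrl: str, MaxNumberOfMessages: int, **_: Any) -> dict:
        batch = self._queue[:MaxNumberOfMessages]
        # Received messages are hidden from later calls, loosely mimicking a visibility timeout.
        self._queue = self._queue[MaxNumberOfMessages:]
        # Like the real API, omit the "Messages" key when nothing was received.
        return {"Messages": batch} if batch else {}

    def delete_message_batch(self, QueueUrl: str, Entries: list[dict]) -> dict:
        self.deleted.extend(entry["Id"] for entry in Entries)
        return {"Successful": [{"Id": entry["Id"]} for entry in Entries]}


def poll_in_batches(
    client: Any,
    queue_url: str,
    max_messages: int = 5,
    num_batches: int = 1,
    delete: bool = True,
) -> list[dict]:
    """Call receive_message ``num_batches`` times and collect all messages."""
    collected: list[dict] = []
    for _ in range(num_batches):
        response = client.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=max_messages)
        messages = response.get("Messages", [])
        if not messages:
            # An empty batch does not stop the loop; later calls may still return messages.
            continue
        collected.extend(messages)
        if delete:
            entries = [
                {"Id": message["MessageId"], "ReceiptHandle": message["ReceiptHandle"]}
                for message in messages
            ]
            result = client.delete_message_batch(QueueUrl=queue_url, Entries=entries)
            if "Successful" not in result:
                raise RuntimeError(f"Delete SQS messages failed: {result}")
    return collected
```

Deleting per batch rather than once at the end keeps each ``delete_message_batch`` call within the size of a single receive call, which is presumably why the diff places the deletion inside the loop.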
diff --git a/airflow/providers/amazon/aws/sensors/step_function.py b/airflow/providers/amazon/aws/sensors/step_function.py
index 6f82c0bc99d8c..fda6f932d8bbd 100644
--- a/airflow/providers/amazon/aws/sensors/step_function.py
+++ b/airflow/providers/amazon/aws/sensors/step_function.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import json
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.step_function import StepFunctionHook
@@ -43,45 +44,45 @@ class StepFunctionExecutionSensor(BaseSensorOperator):
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
"""
- INTERMEDIATE_STATES = ('RUNNING',)
+ INTERMEDIATE_STATES = ("RUNNING",)
FAILURE_STATES = (
- 'FAILED',
- 'TIMED_OUT',
- 'ABORTED',
+ "FAILED",
+ "TIMED_OUT",
+ "ABORTED",
)
- SUCCESS_STATES = ('SUCCEEDED',)
+ SUCCESS_STATES = ("SUCCEEDED",)
- template_fields: Sequence[str] = ('execution_arn',)
+ template_fields: Sequence[str] = ("execution_arn",)
template_ext: Sequence[str] = ()
- ui_color = '#66c3ff'
+ ui_color = "#66c3ff"
def __init__(
self,
*,
execution_arn: str,
- aws_conn_id: str = 'aws_default',
- region_name: Optional[str] = None,
+ aws_conn_id: str = "aws_default",
+ region_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.execution_arn = execution_arn
self.aws_conn_id = aws_conn_id
self.region_name = region_name
- self.hook: Optional[StepFunctionHook] = None
+ self.hook: StepFunctionHook | None = None
- def poke(self, context: 'Context'):
+ def poke(self, context: Context):
execution_status = self.get_hook().describe_execution(self.execution_arn)
- state = execution_status['status']
- output = json.loads(execution_status['output']) if 'output' in execution_status else None
+ state = execution_status["status"]
+ output = json.loads(execution_status["output"]) if "output" in execution_status else None
if state in self.FAILURE_STATES:
- raise AirflowException(f'Step Function sensor failed. State Machine Output: {output}')
+ raise AirflowException(f"Step Function sensor failed. State Machine Output: {output}")
if state in self.INTERMEDIATE_STATES:
return False
- self.log.info('Doing xcom_push of output')
- self.xcom_push(context, 'output', output)
+ self.log.info("Doing xcom_push of output")
+ self.xcom_push(context, "output", output)
return True
def get_hook(self) -> StepFunctionHook:
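Reviewer note: the decision logic in ``StepFunctionExecutionSensor.poke`` can be read in isolation from Airflow. The sketch below restates it as a plain function under a few assumptions: ``evaluate_execution`` is a hypothetical name, ``RuntimeError`` stands in for ``AirflowException``, and the input dict only imitates the ``status`` and ``output`` keys that the diff reads from the ``describe_execution`` result.

```python
from __future__ import annotations

import json
from typing import Any

INTERMEDIATE_STATES = ("RUNNING",)
FAILURE_STATES = ("FAILED", "TIMED_OUT", "ABORTED")


def evaluate_execution(execution_status: dict[str, Any]) -> tuple[bool, Any]:
    """Return (done, output) for an execution description, raising on failure states."""
    state = execution_status["status"]
    # "output" is a JSON string and is only present once the execution has produced one.
    output = json.loads(execution_status["output"]) if "output" in execution_status else None

    if state in FAILURE_STATES:
        raise RuntimeError(f"Step Function sensor failed. State Machine Output: {output}")
    if state in INTERMEDIATE_STATES:
        # Still running: the sensor would return False and poke again later.
        return False, None
    # Any other state is treated as finished; the sensor pushes the output to XCom.
    return True, output
```

Note that, as in the diff, a state is considered finished when it is neither a failure nor an intermediate state; ``SUCCESS_STATES`` is not consulted by ``poke`` itself.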
diff --git a/airflow/providers/amazon/aws/sensors/step_function_execution.py b/airflow/providers/amazon/aws/sensors/step_function_execution.py
deleted file mode 100644
index 267343cd1df9e..0000000000000
--- a/airflow/providers/amazon/aws/sensors/step_function_execution.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.sensors.step_function`."""
-
-import warnings
-
-from airflow.providers.amazon.aws.sensors.step_function import StepFunctionExecutionSensor # noqa
-
-warnings.warn(
- "This module is deprecated. Please use `airflow.providers.amazon.aws.sensors.step_function`.",
- DeprecationWarning,
- stacklevel=2,
-)
diff --git a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
index a6f5f8da21aae..155f5439a6f0e 100644
--- a/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -15,17 +15,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-
"""
This module contains operators to replicate records from
DynamoDB table to S3.
"""
+from __future__ import annotations
+
import json
from copy import copy
from os.path import getsize
from tempfile import NamedTemporaryFile
-from typing import IO, TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Sequence
from uuid import uuid4
from airflow.models import BaseOperator
@@ -36,12 +36,12 @@
from airflow.utils.context import Context
-def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes:
- return (json.dumps(item) + '\n').encode('utf-8')
+def _convert_item_to_json_bytes(item: dict[str, Any]) -> bytes:
+ return (json.dumps(item) + "\n").encode("utf-8")
def _upload_file_to_s3(
- file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = 'aws_default'
+ file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = "aws_default"
) -> None:
s3_client = S3Hook(aws_conn_id=aws_conn_id).get_conn()
file_obj.seek(0)
@@ -80,8 +80,9 @@ class DynamoDBToS3Operator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 's3_bucket_name',
- 'dynamodb_table_name',
+ "s3_bucket_name",
+ "s3_key_prefix",
+ "dynamodb_table_name",
)
template_fields_renderers = {
"dynamodb_scan_kwargs": "json",
@@ -93,10 +94,10 @@ def __init__(
dynamodb_table_name: str,
s3_bucket_name: str,
file_size: int,
- dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None,
- s3_key_prefix: str = '',
- process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
- aws_conn_id: str = 'aws_default',
+ dynamodb_scan_kwargs: dict[str, Any] | None = None,
+ s3_key_prefix: str = "",
+ process_func: Callable[[dict[str, Any]], bytes] = _convert_item_to_json_bytes,
+ aws_conn_id: str = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -108,12 +109,13 @@ def __init__(
self.s3_key_prefix = s3_key_prefix
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
hook = DynamoDBHook(aws_conn_id=self.aws_conn_id)
table = hook.get_conn().Table(self.dynamodb_table_name)
scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
err = None
+ f: IO[Any]
with NamedTemporaryFile() as f:
try:
f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table)
@@ -127,16 +129,16 @@ def execute(self, context: 'Context') -> None:
def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
while True:
response = table.scan(**scan_kwargs)
- items = response['Items']
+ items = response["Items"]
for item in items:
temp_file.write(self.process_func(item))
- if 'LastEvaluatedKey' not in response:
+ if "LastEvaluatedKey" not in response:
# no more items to scan
break
- last_evaluated_key = response['LastEvaluatedKey']
- scan_kwargs['ExclusiveStartKey'] = last_evaluated_key
+ last_evaluated_key = response["LastEvaluatedKey"]
+ scan_kwargs["ExclusiveStartKey"] = last_evaluated_key
# Upload the file to S3 if reach file size limit
if getsize(temp_file.name) >= self.file_size:
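Reviewer note: ``_scan_dynamodb_and_upload_to_s3`` pages through the table by feeding each response's ``LastEvaluatedKey`` back as ``ExclusiveStartKey`` until the key is absent. A minimal sketch of just that pagination pattern follows. ``FakeTable`` and ``scan_all`` are hypothetical names for illustration, and the fake uses a plain index as its key, whereas a real table returns a primary-key dict. The S3 upload and file-size check are omitted.

```python
from __future__ import annotations

from typing import Any, Iterator


class FakeTable:
    """Hypothetical stand-in for a DynamoDB table resource that returns paged scans."""

    def __init__(self, items: list[dict[str, Any]], page_size: int) -> None:
        self._items = items
        self._page_size = page_size

    def scan(self, **kwargs: Any) -> dict[str, Any]:
        start = kwargs.get("ExclusiveStartKey", {}).get("index", 0)
        page = self._items[start : start + self._page_size]
        response: dict[str, Any] = {"Items": page}
        end = start + len(page)
        if end < len(self._items):
            # More items remain, so hand back a continuation key.
            response["LastEvaluatedKey"] = {"index": end}
        return response


def scan_all(table: Any, scan_kwargs: dict[str, Any] | None = None) -> Iterator[dict[str, Any]]:
    """Yield every item from a paged scan, following LastEvaluatedKey."""
    # Copy so the caller's kwargs are not mutated, as the operator does with copy().
    kwargs = dict(scan_kwargs or {})
    while True:
        response = table.scan(**kwargs)
        yield from response["Items"]
        if "LastEvaluatedKey" not in response:
            # No continuation key means the scan is complete.
            break
        kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]
```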
diff --git a/airflow/providers/amazon/aws/transfers/exasol_to_s3.py b/airflow/providers/amazon/aws/transfers/exasol_to_s3.py
index 0f2fb7a99ea10..d5e78adf481b1 100644
--- a/airflow/providers/amazon/aws/transfers/exasol_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/exasol_to_s3.py
@@ -16,9 +16,10 @@
# specific language governing permissions and limitations
# under the License.
"""Transfers data from Exasol database into a S3 Bucket."""
+from __future__ import annotations
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -49,25 +50,25 @@ class ExasolToS3Operator(BaseOperator):
method of :class:`~pyexasol.connection.ExaConnection`.
"""
- template_fields: Sequence[str] = ('query_or_table', 'key', 'bucket_name', 'query_params', 'export_params')
+ template_fields: Sequence[str] = ("query_or_table", "key", "bucket_name", "query_params", "export_params")
template_fields_renderers = {"query_or_table": "sql", "query_params": "json", "export_params": "json"}
- template_ext: Sequence[str] = ('.sql',)
- ui_color = '#ededed'
+ template_ext: Sequence[str] = (".sql",)
+ ui_color = "#ededed"
def __init__(
self,
*,
query_or_table: str,
key: str,
- bucket_name: Optional[str] = None,
+ bucket_name: str | None = None,
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
- acl_policy: Optional[str] = None,
- query_params: Optional[Dict] = None,
- export_params: Optional[Dict] = None,
- exasol_conn_id: str = 'exasol_default',
- aws_conn_id: str = 'aws_default',
+ acl_policy: str | None = None,
+ query_params: dict | None = None,
+ export_params: dict | None = None,
+ exasol_conn_id: str = "exasol_default",
+ aws_conn_id: str = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -83,7 +84,7 @@ def __init__(
self.exasol_conn_id = exasol_conn_id
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
exasol_hook = ExasolHook(exasol_conn_id=self.exasol_conn_id)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/transfers/ftp_to_s3.py b/airflow/providers/amazon/aws/transfers/ftp_to_s3.py
index 1426599bc4763..45811a82e4823 100644
--- a/airflow/providers/amazon/aws/transfers/ftp_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/ftp_to_s3.py
@@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -60,7 +62,7 @@ class FTPToS3Operator(BaseOperator):
uploaded to the S3 bucket.
"""
- template_fields: Sequence[str] = ('ftp_path', 's3_bucket', 's3_key', 'ftp_filenames', 's3_filenames')
+ template_fields: Sequence[str] = ("ftp_path", "s3_bucket", "s3_key", "ftp_filenames", "s3_filenames")
def __init__(
self,
@@ -68,14 +70,14 @@ def __init__(
ftp_path: str,
s3_bucket: str,
s3_key: str,
- ftp_filenames: Optional[Union[str, List[str]]] = None,
- s3_filenames: Optional[Union[str, List[str]]] = None,
- ftp_conn_id: str = 'ftp_default',
- aws_conn_id: str = 'aws_default',
+ ftp_filenames: str | list[str] | None = None,
+ s3_filenames: str | list[str] | None = None,
+ ftp_conn_id: str = "ftp_default",
+ aws_conn_id: str = "aws_default",
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -90,8 +92,8 @@ def __init__(
self.encrypt = encrypt
self.gzip = gzip
self.acl_policy = acl_policy
- self.s3_hook: Optional[S3Hook] = None
- self.ftp_hook: Optional[FTPHook] = None
+ self.s3_hook: S3Hook | None = None
+ self.ftp_hook: FTPHook | None = None
def __upload_to_s3_from_ftp(self, remote_filename, s3_file_key):
with NamedTemporaryFile() as local_tmp_file:
@@ -108,35 +110,35 @@ def __upload_to_s3_from_ftp(self, remote_filename, s3_file_key):
gzip=self.gzip,
acl_policy=self.acl_policy,
)
- self.log.info('File upload to %s', s3_file_key)
+ self.log.info("File upload to %s", s3_file_key)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
self.ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id)
self.s3_hook = S3Hook(self.aws_conn_id)
if self.ftp_filenames:
if isinstance(self.ftp_filenames, str):
- self.log.info('Getting files in %s', self.ftp_path)
+ self.log.info("Getting files in %s", self.ftp_path)
list_dir = self.ftp_hook.list_directory(
path=self.ftp_path,
)
- if self.ftp_filenames == '*':
+ if self.ftp_filenames == "*":
files = list_dir
else:
ftp_filename: str = self.ftp_filenames
files = list(filter(lambda f: ftp_filename in f, list_dir))
for file in files:
- self.log.info('Moving file %s', file)
+ self.log.info("Moving file %s", file)
if self.s3_filenames and isinstance(self.s3_filenames, str):
filename = file.replace(self.ftp_filenames, self.s3_filenames)
else:
filename = file
- s3_file_key = f'{self.s3_key}{filename}'
+ s3_file_key = f"{self.s3_key}{filename}"
self.__upload_to_s3_from_ftp(file, s3_file_key)
else:
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index b521ce5360741..6a4a58103689b 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -16,9 +16,11 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Google Cloud Storage to S3 operator."""
+from __future__ import annotations
+
import os
import warnings
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -83,30 +85,30 @@ class GCSToS3Operator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'bucket',
- 'prefix',
- 'delimiter',
- 'dest_s3_key',
- 'google_impersonation_chain',
+ "bucket",
+ "prefix",
+ "delimiter",
+ "dest_s3_key",
+ "google_impersonation_chain",
)
- ui_color = '#f0eee4'
+ ui_color = "#f0eee4"
def __init__(
self,
*,
bucket: str,
- prefix: Optional[str] = None,
- delimiter: Optional[str] = None,
- gcp_conn_id: str = 'google_cloud_default',
- google_cloud_storage_conn_id: Optional[str] = None,
- delegate_to: Optional[str] = None,
- dest_aws_conn_id: str = 'aws_default',
+ prefix: str | None = None,
+ delimiter: str | None = None,
+ gcp_conn_id: str = "google_cloud_default",
+ google_cloud_storage_conn_id: str | None = None,
+ delegate_to: str | None = None,
+ dest_aws_conn_id: str = "aws_default",
dest_s3_key: str,
- dest_verify: Optional[Union[str, bool]] = None,
+ dest_verify: str | bool | None = None,
replace: bool = False,
- google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
- dest_s3_extra_args: Optional[Dict] = None,
- s3_acl_policy: Optional[str] = None,
+ google_impersonation_chain: str | Sequence[str] | None = None,
+ dest_s3_extra_args: dict | None = None,
+ s3_acl_policy: str | None = None,
keep_directory_structure: bool = True,
**kwargs,
) -> None:
@@ -135,7 +137,7 @@ def __init__(
self.s3_acl_policy = s3_acl_policy
self.keep_directory_structure = keep_directory_structure
- def execute(self, context: 'Context') -> List[str]:
+ def execute(self, context: Context) -> list[str]:
# list all files in an Google Cloud Storage bucket
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
@@ -144,7 +146,7 @@ def execute(self, context: 'Context') -> List[str]:
)
self.log.info(
- 'Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
+ "Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s",
self.bucket,
self.delimiter,
self.prefix,
@@ -170,7 +172,7 @@ def execute(self, context: 'Context') -> List[str]:
# if no files exist, return an empty list to avoid errors
existing_files = existing_files if existing_files is not None else []
# remove the prefix for the existing files to allow the match
- existing_files = [file.replace(prefix, '', 1) for file in existing_files]
+ existing_files = [file.replace(prefix, "", 1) for file in existing_files]
files = list(set(files) - set(existing_files))
if files:
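The filtering step visible in the hunk above (strip the prefix from the keys already at the destination, then take a set difference) can be sketched in isolation. This is a minimal illustration, not the operator's code: the helper name `files_missing_from_destination` is hypothetical, and the result is sorted only to make it deterministic.

```python
from __future__ import annotations


def files_missing_from_destination(
    files: list[str], existing_files: list[str] | None, prefix: str
) -> list[str]:
    """Hypothetical helper mirroring the filtering step shown in the diff."""
    # Treat a missing listing as an empty one to avoid errors.
    existing = existing_files if existing_files is not None else []
    # Remove the first occurrence of the prefix so the names are comparable.
    existing = [name.replace(prefix, "", 1) for name in existing]
    # The diff uses list(set(...)); sorting here only fixes the order.
    return sorted(set(files) - set(existing))
```

For example, with `files=["a.csv", "b.csv"]`, `existing_files=["backup/a.csv"]` and `prefix="backup/"`, only `b.csv` remains to be copied.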
diff --git a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
index 07d3410ee7334..4a189230a4bb1 100644
--- a/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
+++ b/airflow/providers/amazon/aws/transfers/glacier_to_gcs.py
@@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import tempfile
-from typing import TYPE_CHECKING, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.glacier import GlacierHook
@@ -71,8 +73,8 @@ def __init__(
object_name: str,
gzip: bool,
chunk_size: int = 1024,
- delegate_to: Optional[str] = None,
- google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ delegate_to: str | None = None,
+ google_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -86,7 +88,7 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = google_impersonation_chain
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
glacier_hook = GlacierHook(aws_conn_id=self.aws_conn_id)
gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
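The recurring change in these files is `from __future__ import annotations`, which stores annotations as strings instead of evaluating them. That is what allows `str | None` in signatures and the unquoted `Context` name that is imported only under `TYPE_CHECKING`. A minimal sketch of the pattern follows; `describe` and the `Decimal` stand-in are illustrative assumptions and are not part of Airflow.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

if TYPE_CHECKING:
    # Type-only import, never executed at runtime. Decimal is a stand-in
    # for a name such as Airflow's Context.
    from decimal import Decimal


def describe(
    impersonation_chain: str | Sequence[str] | None = None,
    amount: Decimal | None = None,
) -> list[str]:
    """Illustrative function: normalise the chain argument to a list."""
    if impersonation_chain is None:
        return []
    if isinstance(impersonation_chain, str):
        return [impersonation_chain]
    return list(impersonation_chain)
```

Because the annotations are never evaluated when the function is defined, the missing runtime import of `Decimal` does not raise a `NameError`.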
diff --git a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
index f3e10b62b2aa0..223b13472948c 100644
--- a/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/google_api_to_s3.py
@@ -15,11 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""This module allows you to transfer data from any Google API endpoint into a S3 Bucket."""
+from __future__ import annotations
+
import json
import sys
-from typing import TYPE_CHECKING, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator, TaskInstance
from airflow.models.xcom import MAX_XCOM_SIZE, XCOM_RETURN_KEY
@@ -57,6 +58,10 @@ class GoogleApiToS3Operator(BaseOperator):
:param google_api_endpoint_params: The params to control the corresponding endpoint result.
:param s3_destination_key: The url where to put the data retrieved from the endpoint in S3.
+
+ .. note:: See https://docs.aws.amazon.com/AmazonS3/latest/userguide/access-bucket-intro.html
+ for valid URL formats.
+
:param google_api_response_via_xcom: Can be set to expose the google api response to xcom.
:param google_api_endpoint_params_via_xcom: If set to a value this value will be used as a key
for pulling from xcom and updating the google api endpoint params.
@@ -85,12 +90,13 @@ class GoogleApiToS3Operator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'google_api_endpoint_params',
- 's3_destination_key',
- 'google_impersonation_chain',
+ "google_api_endpoint_params",
+ "s3_destination_key",
+ "google_impersonation_chain",
+ "gcp_conn_id",
)
template_ext: Sequence[str] = ()
- ui_color = '#cc181e'
+ ui_color = "#cc181e"
def __init__(
self,
@@ -100,16 +106,16 @@ def __init__(
google_api_endpoint_path: str,
google_api_endpoint_params: dict,
s3_destination_key: str,
- google_api_response_via_xcom: Optional[str] = None,
- google_api_endpoint_params_via_xcom: Optional[str] = None,
- google_api_endpoint_params_via_xcom_task_ids: Optional[str] = None,
+ google_api_response_via_xcom: str | None = None,
+ google_api_endpoint_params_via_xcom: str | None = None,
+ google_api_endpoint_params_via_xcom_task_ids: str | None = None,
google_api_pagination: bool = False,
google_api_num_retries: int = 0,
s3_overwrite: bool = False,
- gcp_conn_id: str = 'google_cloud_default',
- delegate_to: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
- google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
+ gcp_conn_id: str = "google_cloud_default",
+ delegate_to: str | None = None,
+ aws_conn_id: str = "aws_default",
+ google_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -129,23 +135,23 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.google_impersonation_chain = google_impersonation_chain
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
"""
Transfers Google APIs json data to S3.
:param context: The context that is being provided when executing.
"""
- self.log.info('Transferring data from %s to s3', self.google_api_service_name)
+ self.log.info("Transferring data from %s to s3", self.google_api_service_name)
if self.google_api_endpoint_params_via_xcom:
- self._update_google_api_endpoint_params_via_xcom(context['task_instance'])
+ self._update_google_api_endpoint_params_via_xcom(context["task_instance"])
data = self._retrieve_data_from_google_api()
self._load_data_to_s3(data)
if self.google_api_response_via_xcom:
- self._expose_google_api_response_via_xcom(context['task_instance'], data)
+ self._expose_google_api_response_via_xcom(context["task_instance"], data)
def _retrieve_data_from_google_api(self) -> dict:
google_discovery_api_hook = GoogleDiscoveryApiHook(
@@ -165,7 +171,10 @@ def _retrieve_data_from_google_api(self) -> dict:
def _load_data_to_s3(self, data: dict) -> None:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id)
s3_hook.load_string(
- string_data=json.dumps(data), key=self.s3_destination_key, replace=self.s3_overwrite
+ string_data=json.dumps(data),
+ bucket_name=S3Hook.parse_s3_url(self.s3_destination_key)[0],
+ key=S3Hook.parse_s3_url(self.s3_destination_key)[1],
+ replace=self.s3_overwrite,
)
def _update_google_api_endpoint_params_via_xcom(self, task_instance: TaskInstance) -> None:
@@ -181,4 +190,4 @@ def _expose_google_api_response_via_xcom(self, task_instance: TaskInstance, data
if sys.getsizeof(data) < MAX_XCOM_SIZE:
task_instance.xcom_push(key=self.google_api_response_via_xcom or XCOM_RETURN_KEY, value=data)
else:
- raise RuntimeError('The size of the downloaded data is too large to push to XCom!')
+ raise RuntimeError("The size of the downloaded data is too large to push to XCom!")
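In the diff above, `_load_data_to_s3` now derives the bucket and the key from `s3_destination_key` by indexing the result of `S3Hook.parse_s3_url` with `[0]` and `[1]`. The sketch below approximates that split for the `s3://bucket/key` form only, using the standard library. It is not the hook's implementation, and `split_s3_url` is a hypothetical name.

```python
from __future__ import annotations

from urllib.parse import urlsplit


def split_s3_url(s3_url: str) -> tuple[str, str]:
    """Hypothetical helper: split an s3://bucket/key URL into (bucket, key)."""
    parts = urlsplit(s3_url)
    if parts.scheme != "s3" or not parts.netloc:
        raise ValueError(f"Expected an s3://bucket/key URL, got: {s3_url!r}")
    # The bucket is the network location; the key is the path without its
    # leading slash.
    return parts.netloc, parts.path.lstrip("/")
```

Other URL styles described in the AWS documentation linked from the docstring are outside the scope of this sketch.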
diff --git a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
index 652eca22b86df..167665d627b82 100644
--- a/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
+++ b/airflow/providers/amazon/aws/transfers/hive_to_dynamodb.py
@@ -15,11 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains operator to move data from Hive to DynamoDB."""
+from __future__ import annotations
import json
-from typing import TYPE_CHECKING, Callable, Optional, Sequence
+from typing import TYPE_CHECKING, Callable, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
@@ -52,10 +52,10 @@ class HiveToDynamoDBOperator(BaseOperator):
:param aws_conn_id: aws connection
"""
- template_fields: Sequence[str] = ('sql',)
- template_ext: Sequence[str] = ('.sql',)
+ template_fields: Sequence[str] = ("sql",)
+ template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"sql": "hql"}
- ui_color = '#a0e08c'
+ ui_color = "#a0e08c"
def __init__(
self,
@@ -63,13 +63,13 @@ def __init__(
sql: str,
table_name: str,
table_keys: list,
- pre_process: Optional[Callable] = None,
- pre_process_args: Optional[list] = None,
- pre_process_kwargs: Optional[list] = None,
- region_name: Optional[str] = None,
- schema: str = 'default',
- hiveserver2_conn_id: str = 'hiveserver2_default',
- aws_conn_id: str = 'aws_default',
+ pre_process: Callable | None = None,
+ pre_process_args: list | None = None,
+ pre_process_kwargs: list | None = None,
+ region_name: str | None = None,
+ schema: str = "default",
+ hiveserver2_conn_id: str = "hiveserver2_default",
+ aws_conn_id: str = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -84,10 +84,10 @@ def __init__(
self.hiveserver2_conn_id = hiveserver2_conn_id
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
- self.log.info('Extracting data from Hive')
+ self.log.info("Extracting data from Hive")
self.log.info(self.sql)
data = hive.get_pandas_df(self.sql, schema=self.schema)
@@ -98,13 +98,13 @@ def execute(self, context: 'Context'):
region_name=self.region_name,
)
- self.log.info('Inserting rows into dynamodb')
+ self.log.info("Inserting rows into dynamodb")
if self.pre_process is None:
- dynamodb.write_batch_data(json.loads(data.to_json(orient='records')))
+ dynamodb.write_batch_data(json.loads(data.to_json(orient="records")))
else:
dynamodb.write_batch_data(
self.pre_process(data=data, args=self.pre_process_args, kwargs=self.pre_process_kwargs)
)
- self.log.info('Done.')
+ self.log.info("Done.")
diff --git a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
index e79276dbc7190..b4394d6389254 100644
--- a/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/imap_attachment_to_s3.py
@@ -16,8 +16,10 @@
# specific language governing permissions and limitations
# under the License.
"""This module allows you to transfer mail attachments from a mail server into s3 bucket."""
+from __future__ import annotations
+
import warnings
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -51,7 +53,7 @@ class ImapAttachmentToS3Operator(BaseOperator):
:param aws_conn_id: AWS connection to use.
"""
- template_fields: Sequence[str] = ('imap_attachment_name', 's3_key', 'imap_mail_filter')
+ template_fields: Sequence[str] = ("imap_attachment_name", "s3_key", "imap_mail_filter")
def __init__(
self,
@@ -60,12 +62,12 @@ def __init__(
s3_bucket: str,
s3_key: str,
imap_check_regex: bool = False,
- imap_mail_folder: str = 'INBOX',
- imap_mail_filter: str = 'All',
+ imap_mail_folder: str = "INBOX",
+ imap_mail_filter: str = "All",
s3_overwrite: bool = False,
- imap_conn_id: str = 'imap_default',
- s3_conn_id: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
+ imap_conn_id: str = "imap_default",
+ s3_conn_id: str | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -83,14 +85,14 @@ def __init__(
self.imap_conn_id = imap_conn_id
self.aws_conn_id = aws_conn_id
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
"""
This function executes the transfer from the email server (via imap) into s3.
:param context: The context while executing.
"""
self.log.info(
- 'Transferring mail attachment %s from mail server via imap to s3 key %s...',
+ "Transferring mail attachment %s from mail server via imap to s3 key %s...",
self.imap_attachment_name,
self.s3_key,
)
diff --git a/airflow/providers/amazon/aws/transfers/local_to_s3.py b/airflow/providers/amazon/aws/transfers/local_to_s3.py
index d3e76dcb1d275..d431935738b09 100644
--- a/airflow/providers/amazon/aws/transfers/local_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/local_to_s3.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Optional, Sequence, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -62,20 +64,20 @@ class LocalFilesystemToS3Operator(BaseOperator):
uploaded to the S3 bucket.
"""
- template_fields: Sequence[str] = ('filename', 'dest_key', 'dest_bucket')
+ template_fields: Sequence[str] = ("filename", "dest_key", "dest_bucket")
def __init__(
self,
*,
filename: str,
dest_key: str,
- dest_bucket: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[str, bool]] = None,
+ dest_bucket: str | None = None,
+ aws_conn_id: str = "aws_default",
+ verify: str | bool | None = None,
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -90,10 +92,10 @@ def __init__(
self.gzip = gzip
self.acl_policy = acl_policy
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
s3_bucket, s3_key = s3_hook.get_s3_bucket_key(
- self.dest_bucket, self.dest_key, 'dest_bucket', 'dest_key'
+ self.dest_bucket, self.dest_key, "dest_bucket", "dest_key"
)
s3_hook.load_file(
self.filename,
diff --git a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
index 44aae36378a29..eaa41f114d7da 100644
--- a/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/mongo_to_s3.py
@@ -15,9 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import json
import warnings
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union, cast
+from typing import TYPE_CHECKING, Any, Iterable, Sequence, cast
from bson import json_util
@@ -57,25 +59,25 @@ class MongoToS3Operator(BaseOperator):
:param compression: type of compression to use for output file in S3. Currently only gzip is supported.
"""
- template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'mongo_query', 'mongo_collection')
- ui_color = '#589636'
+ template_fields: Sequence[str] = ("s3_bucket", "s3_key", "mongo_query", "mongo_collection")
+ ui_color = "#589636"
template_fields_renderers = {"mongo_query": "json"}
def __init__(
self,
*,
- s3_conn_id: Optional[str] = None,
- mongo_conn_id: str = 'mongo_default',
- aws_conn_id: str = 'aws_default',
+ s3_conn_id: str | None = None,
+ mongo_conn_id: str = "mongo_default",
+ aws_conn_id: str = "aws_default",
mongo_collection: str,
- mongo_query: Union[list, dict],
+ mongo_query: list | dict,
s3_bucket: str,
s3_key: str,
- mongo_db: Optional[str] = None,
- mongo_projection: Optional[Union[list, dict]] = None,
+ mongo_db: str | None = None,
+ mongo_projection: list | dict | None = None,
replace: bool = False,
allow_disk_use: bool = False,
- compression: Optional[str] = None,
+ compression: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -99,7 +101,7 @@ def __init__(
self.allow_disk_use = allow_disk_use
self.compression = compression
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""Is written to depend on transform method"""
s3_conn = S3Hook(self.aws_conn_id)
@@ -132,7 +134,7 @@ def execute(self, context: 'Context'):
)
@staticmethod
- def _stringify(iterable: Iterable, joinable: str = '\n') -> str:
+ def _stringify(iterable: Iterable, joinable: str = "\n") -> str:
"""
Takes an iterable (pymongo Cursor or Array) containing dictionaries and
returns a stringified version using python join
diff --git a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py b/airflow/providers/amazon/aws/transfers/mysql_to_s3.py
deleted file mode 100644
index dc3d84ecb3658..0000000000000
--- a/airflow/providers/amazon/aws/transfers/mysql_to_s3.py
+++ /dev/null
@@ -1,72 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import warnings
-from typing import Optional
-
-from airflow.exceptions import AirflowException
-from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator
-
-warnings.warn(
- "This module is deprecated. Please use airflow.providers.amazon.aws.transfers.sql_to_s3`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class MySQLToS3Operator(SqlToS3Operator):
- """
- This class is deprecated.
- Please use `airflow.providers.amazon.aws.transfers.sql_to_s3.SqlToS3Operator`.
- """
-
- template_fields_renderers = {
- "pd_csv_kwargs": "json",
- }
-
- def __init__(
- self,
- *,
- mysql_conn_id: str = 'mysql_default',
- pd_csv_kwargs: Optional[dict] = None,
- index: bool = False,
- header: bool = False,
- **kwargs,
- ) -> None:
- warnings.warn(
- """
- MySQLToS3Operator is deprecated.
- Please use `airflow.providers.amazon.aws.transfers.sql_to_s3.SqlToS3Operator`.
- """,
- DeprecationWarning,
- stacklevel=2,
- )
-
- pd_kwargs = kwargs.get('pd_kwargs', {})
- if kwargs.get('file_format', "csv") == "csv":
- if "path_or_buf" in pd_kwargs:
- raise AirflowException('The argument path_or_buf is not allowed, please remove it')
- if "index" not in pd_kwargs:
- pd_kwargs["index"] = index
- if "header" not in pd_kwargs:
- pd_kwargs["header"] = header
- kwargs["pd_kwargs"] = {**kwargs.get('pd_kwargs', {}), **pd_kwargs}
- elif pd_csv_kwargs is not None:
- raise TypeError("pd_csv_kwargs may not be specified when file_format='parquet'")
-
- super().__init__(sql_conn_id=mysql_conn_id, **kwargs)
diff --git a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
index 0bfdda44e78d6..ed57df0984054 100644
--- a/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -16,7 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""Transfers data from AWS Redshift into a S3 Bucket."""
-from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Sequence, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
@@ -29,7 +31,7 @@
class RedshiftToS3Operator(BaseOperator):
"""
- Executes an UNLOAD command to s3 as a CSV with headers
+ Execute an UNLOAD command to s3 as a CSV with headers.
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -68,44 +70,45 @@ class RedshiftToS3Operator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 's3_bucket',
- 's3_key',
- 'schema',
- 'table',
- 'unload_options',
- 'select_query',
+ "s3_bucket",
+ "s3_key",
+ "schema",
+ "table",
+ "unload_options",
+ "select_query",
+ "redshift_conn_id",
)
- template_ext: Sequence[str] = ('.sql',)
- template_fields_renderers = {'select_query': 'sql'}
- ui_color = '#ededed'
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"select_query": "sql"}
+ ui_color = "#ededed"
def __init__(
self,
*,
s3_bucket: str,
s3_key: str,
- schema: Optional[str] = None,
- table: Optional[str] = None,
- select_query: Optional[str] = None,
- redshift_conn_id: str = 'redshift_default',
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- unload_options: Optional[List] = None,
+ schema: str | None = None,
+ table: str | None = None,
+ select_query: str | None = None,
+ redshift_conn_id: str = "redshift_default",
+ aws_conn_id: str = "aws_default",
+ verify: bool | str | None = None,
+ unload_options: list | None = None,
autocommit: bool = False,
include_header: bool = False,
- parameters: Optional[Union[Mapping, Iterable]] = None,
+ parameters: Iterable | Mapping | None = None,
table_as_file_name: bool = True, # Set to True by default for not breaking current workflows
**kwargs,
) -> None:
super().__init__(**kwargs)
self.s3_bucket = s3_bucket
- self.s3_key = f'{s3_key}/{table}_' if (table and table_as_file_name) else s3_key
+ self.s3_key = f"{s3_key}/{table}_" if (table and table_as_file_name) else s3_key
self.schema = schema
self.table = table
self.redshift_conn_id = redshift_conn_id
self.aws_conn_id = aws_conn_id
self.verify = verify
- self.unload_options = unload_options or [] # type: List
+ self.unload_options: list = unload_options or []
self.autocommit = autocommit
self.include_header = include_header
self.parameters = parameters
@@ -117,12 +120,12 @@ def __init__(
self.select_query = f"SELECT * FROM {self.schema}.{self.table}"
else:
raise ValueError(
- 'Please provide both `schema` and `table` params or `select_query` to fetch the data.'
+ "Please provide both `schema` and `table` params or `select_query` to fetch the data."
)
- if self.include_header and 'HEADER' not in [uo.upper().strip() for uo in self.unload_options]:
+ if self.include_header and "HEADER" not in [uo.upper().strip() for uo in self.unload_options]:
self.unload_options = list(self.unload_options) + [
- 'HEADER',
+ "HEADER",
]
def _build_unload_query(
@@ -136,22 +139,22 @@ def _build_unload_query(
{unload_options};
"""
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
- if conn.extra_dejson.get('role_arn', False):
+ if conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
else:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
credentials = s3_hook.get_credentials()
credentials_block = build_credentials_block(credentials)
- unload_options = '\n\t\t\t'.join(self.unload_options)
+ unload_options = "\n\t\t\t".join(self.unload_options)
unload_query = self._build_unload_query(
credentials_block, self.select_query, self.s3_key, unload_options
)
- self.log.info('Executing UNLOAD command...')
+ self.log.info("Executing UNLOAD command...")
redshift_hook.run(unload_query, self.autocommit, parameters=self.parameters)
self.log.info("UNLOAD command complete...")
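The constructor logic shown above appends `HEADER` to the unload options only when `include_header` is set and no existing option already equals `HEADER` after upper-casing and stripping. A small sketch of that rule, with the hypothetical helper name `with_header`:

```python
from __future__ import annotations


def with_header(unload_options: list[str] | None, include_header: bool) -> list[str]:
    """Hypothetical helper: add HEADER once, matching case-insensitively."""
    options = list(unload_options or [])
    already_present = "HEADER" in [opt.upper().strip() for opt in options]
    if include_header and not already_present:
        options.append("HEADER")
    return options
```

So `with_header(["CSV"], True)` yields `["CSV", "HEADER"]`, while `with_header([" header "], True)` leaves the list unchanged.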
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_ftp.py b/airflow/providers/amazon/aws/transfers/s3_to_ftp.py
index 2e07a9575fd26..9a15772310747 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_ftp.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_ftp.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
@@ -46,7 +47,7 @@ class S3ToFTPOperator(BaseOperator):
establishing a connection to the FTP server.
"""
- template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'ftp_path')
+ template_fields: Sequence[str] = ("s3_bucket", "s3_key", "ftp_path")
def __init__(
self,
@@ -54,8 +55,8 @@ def __init__(
s3_bucket,
s3_key,
ftp_path,
- aws_conn_id='aws_default',
- ftp_conn_id='ftp_default',
+ aws_conn_id="aws_default",
+ ftp_conn_id="ftp_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -65,15 +66,15 @@ def __init__(
self.aws_conn_id = aws_conn_id
self.ftp_conn_id = ftp_conn_id
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
s3_hook = S3Hook(self.aws_conn_id)
ftp_hook = FTPHook(ftp_conn_id=self.ftp_conn_id)
s3_obj = s3_hook.get_key(self.s3_key, self.s3_bucket)
with NamedTemporaryFile() as local_tmp_file:
- self.log.info('Downloading file from %s', self.s3_key)
+ self.log.info("Downloading file from %s", self.s3_key)
s3_obj.download_fileobj(local_tmp_file)
local_tmp_file.seek(0)
ftp_hook.store_file(self.ftp_path, local_tmp_file.name)
- self.log.info('File stored in %s', {self.ftp_path})
+ self.log.info("File stored in %s", self.ftp_path)
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 014e23ec070f2..bbb83310b7bb7 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-import warnings
-from typing import TYPE_CHECKING, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -28,7 +28,7 @@
from airflow.utils.context import Context
-AVAILABLE_METHODS = ['APPEND', 'REPLACE', 'UPSERT']
+AVAILABLE_METHODS = ["APPEND", "REPLACE", "UPSERT"]
class S3ToRedshiftOperator(BaseOperator):
@@ -64,9 +64,17 @@ class S3ToRedshiftOperator(BaseOperator):
:param upsert_keys: List of fields to use as key on upsert action
"""
- template_fields: Sequence[str] = ('s3_bucket', 's3_key', 'schema', 'table', 'column_list', 'copy_options')
+ template_fields: Sequence[str] = (
+ "s3_bucket",
+ "s3_key",
+ "schema",
+ "table",
+ "column_list",
+ "copy_options",
+ "redshift_conn_id",
+ )
template_ext: Sequence[str] = ()
- ui_color = '#99e699'
+ ui_color = "#99e699"
def __init__(
self,
@@ -75,27 +83,16 @@ def __init__(
table: str,
s3_bucket: str,
s3_key: str,
- redshift_conn_id: str = 'redshift_default',
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- column_list: Optional[List[str]] = None,
- copy_options: Optional[List] = None,
+ redshift_conn_id: str = "redshift_default",
+ aws_conn_id: str = "aws_default",
+ verify: bool | str | None = None,
+ column_list: list[str] | None = None,
+ copy_options: list | None = None,
autocommit: bool = False,
- method: str = 'APPEND',
- upsert_keys: Optional[List[str]] = None,
+ method: str = "APPEND",
+ upsert_keys: list[str] | None = None,
**kwargs,
) -> None:
-
- if 'truncate_table' in kwargs:
- warnings.warn(
- """`truncate_table` is deprecated. Please use `REPLACE` method.""",
- DeprecationWarning,
- stacklevel=2,
- )
- if kwargs['truncate_table']:
- method = 'REPLACE'
- kwargs.pop('truncate_table', None)
-
super().__init__(**kwargs)
self.schema = schema
self.table = table
@@ -111,10 +108,10 @@ def __init__(
self.upsert_keys = upsert_keys
if self.method not in AVAILABLE_METHODS:
- raise AirflowException(f'Method not found! Available methods: {AVAILABLE_METHODS}')
+ raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}")
def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_options: str) -> str:
- column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else ''
+ column_names = "(" + ", ".join(self.column_list) + ")" if self.column_list else ""
return f"""
COPY {copy_destination} {column_names}
FROM 's3://{self.s3_bucket}/{self.s3_key}'
@@ -123,34 +120,34 @@ def _build_copy_query(self, copy_destination: str, credentials_block: str, copy_
{copy_options};
"""
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
conn = S3Hook.get_connection(conn_id=self.aws_conn_id)
- if conn.extra_dejson.get('role_arn', False):
+ if conn.extra_dejson.get("role_arn", False):
credentials_block = f"aws_iam_role={conn.extra_dejson['role_arn']}"
else:
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
credentials = s3_hook.get_credentials()
credentials_block = build_credentials_block(credentials)
- copy_options = '\n\t\t\t'.join(self.copy_options)
- destination = f'{self.schema}.{self.table}'
- copy_destination = f'#{self.table}' if self.method == 'UPSERT' else destination
+ copy_options = "\n\t\t\t".join(self.copy_options)
+ destination = f"{self.schema}.{self.table}"
+ copy_destination = f"#{self.table}" if self.method == "UPSERT" else destination
copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)
- sql: Union[list, str]
+ sql: str | Iterable[str]
- if self.method == 'REPLACE':
+ if self.method == "REPLACE":
sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
- elif self.method == 'UPSERT':
+ elif self.method == "UPSERT":
keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema)
if not keys:
raise AirflowException(
f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
)
- where_statement = ' AND '.join([f'{self.table}.{k} = {copy_destination}.{k}' for k in keys])
+ where_statement = " AND ".join([f"{self.table}.{k} = {copy_destination}.{k}" for k in keys])
sql = [
f"CREATE TABLE {copy_destination} (LIKE {destination});",
@@ -164,6 +161,6 @@ def execute(self, context: 'Context') -> None:
else:
sql = copy_statement
- self.log.info('Executing COPY command...')
+ self.log.info("Executing COPY command...")
redshift_hook.run(sql, autocommit=self.autocommit)
self.log.info("COPY command complete...")
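For the `UPSERT` method, the code above copies into a temporary table named `#<table>` and joins it back to the target on the upsert keys. The join condition can be reproduced on its own, as sketched below. `build_upsert_where` is a hypothetical name, and the surrounding SQL statements are not reproduced here because the hunk elides them.

```python
from __future__ import annotations


def build_upsert_where(table: str, keys: list[str]) -> str:
    """Hypothetical helper: build the key-matching condition for an upsert."""
    # The staging table carries the target's name with a leading '#'.
    staging = f"#{table}"
    return " AND ".join(f"{table}.{key} = {staging}.{key}" for key in keys)
```

With `table="orders"` and `keys=["id", "region"]` this gives `orders.id = #orders.id AND orders.region = #orders.region`.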
diff --git a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
index 7c003cfb72608..3038c17cb0a6c 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_sftp.py
@@ -15,10 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import warnings
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Optional, Sequence
-from urllib.parse import urlparse
+from typing import TYPE_CHECKING, Sequence
+from urllib.parse import urlsplit
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -51,7 +53,7 @@ class S3ToSFTPOperator(BaseOperator):
downloading the file from S3.
"""
- template_fields: Sequence[str] = ('s3_key', 'sftp_path')
+ template_fields: Sequence[str] = ("s3_key", "sftp_path")
def __init__(
self,
@@ -59,9 +61,9 @@ def __init__(
s3_bucket: str,
s3_key: str,
sftp_path: str,
- sftp_conn_id: str = 'ssh_default',
- s3_conn_id: Optional[str] = None,
- aws_conn_id: str = 'aws_default',
+ sftp_conn_id: str = "ssh_default",
+ s3_conn_id: str | None = None,
+ aws_conn_id: str = "aws_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -78,10 +80,10 @@ def __init__(
@staticmethod
def get_s3_key(s3_key: str) -> str:
"""This parses the correct format for S3 keys regardless of how the S3 url is passed."""
- parsed_s3_key = urlparse(s3_key)
- return parsed_s3_key.path.lstrip('/')
+ parsed_s3_key = urlsplit(s3_key)
+ return parsed_s3_key.path.lstrip("/")
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
self.s3_key = self.get_s3_key(self.s3_key)
ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
s3_hook = S3Hook(self.aws_conn_id)
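
For reference, the behaviour `get_s3_key` relies on can be checked in isolation. This is a minimal sketch, assuming a standalone copy of the helper; the bucket and key names are made up for illustration. `urlsplit` puts the bucket in `netloc` and the key in `path`, so both a full `s3://` URL and a bare key reduce to the same value.

```python
from urllib.parse import urlsplit


def get_s3_key(s3_key: str) -> str:
    """Return the object key whether a full S3 URL or a bare key is passed."""
    # For "s3://bucket/path/to/file.csv" the path component is "/path/to/file.csv";
    # for a bare key the whole string is the path.
    return urlsplit(s3_key).path.lstrip("/")


# Hypothetical inputs, used only for illustration.
full_url_key = get_s3_key("s3://example-bucket/path/to/file.csv")
bare_key = get_s3_key("path/to/file.csv")
```

Both calls yield the same key, which is why the operator can normalise `self.s3_key` before using it.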
diff --git a/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py b/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py
index a953693f1f193..f0d9c82f3fe63 100644
--- a/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/salesforce_to_s3.py
@@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import os
import tempfile
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -70,7 +71,7 @@ def __init__(
s3_key: str,
salesforce_conn_id: str,
export_format: str = "csv",
- query_params: Optional[Dict] = None,
+ query_params: dict | None = None,
include_deleted: bool = False,
coerce_to_timestamp: bool = False,
record_time_added: bool = False,
@@ -78,7 +79,7 @@ def __init__(
replace: bool = False,
encrypt: bool = False,
gzip: bool = False,
- acl_policy: Optional[str] = None,
+ acl_policy: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -97,7 +98,7 @@ def __init__(
self.gzip = gzip
self.acl_policy = acl_policy
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
salesforce_hook = SalesforceHook(salesforce_conn_id=self.salesforce_conn_id)
response = salesforce_hook.make_query(
query=self.salesforce_query,
diff --git a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
index 71376e3179c0a..546c4710267b4 100644
--- a/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sftp_to_s3.py
@@ -15,9 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
-from urllib.parse import urlparse
+from urllib.parse import urlsplit
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -50,7 +52,7 @@ class SFTPToS3Operator(BaseOperator):
if False streams file from SFTP to S3.
"""
- template_fields: Sequence[str] = ('s3_key', 'sftp_path')
+ template_fields: Sequence[str] = ("s3_key", "sftp_path")
def __init__(
self,
@@ -58,8 +60,8 @@ def __init__(
s3_bucket: str,
s3_key: str,
sftp_path: str,
- sftp_conn_id: str = 'ssh_default',
- s3_conn_id: str = 'aws_default',
+ sftp_conn_id: str = "ssh_default",
+ s3_conn_id: str = "aws_default",
use_temp_file: bool = True,
**kwargs,
) -> None:
@@ -74,10 +76,10 @@ def __init__(
@staticmethod
def get_s3_key(s3_key: str) -> str:
"""This parses the correct format for S3 keys regardless of how the S3 url is passed."""
- parsed_s3_key = urlparse(s3_key)
- return parsed_s3_key.path.lstrip('/')
+ parsed_s3_key = urlsplit(s3_key)
+ return parsed_s3_key.path.lstrip("/")
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
self.s3_key = self.get_s3_key(self.s3_key)
ssh_hook = SSHHook(ssh_conn_id=self.sftp_conn_id)
s3_hook = S3Hook(self.s3_conn_id)
@@ -90,5 +92,5 @@ def execute(self, context: 'Context') -> None:
s3_hook.load_file(filename=f.name, key=self.s3_key, bucket_name=self.s3_bucket, replace=True)
else:
- with sftp_client.file(self.sftp_path, mode='rb') as data:
+ with sftp_client.file(self.sftp_path, mode="rb") as data:
s3_hook.get_conn().upload_fileobj(data, self.s3_bucket, self.s3_key, Callback=self.log.info)
diff --git a/airflow/providers/amazon/aws/transfers/sql_to_s3.py b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
index f399c271416e4..713fbc005901f 100644
--- a/airflow/providers/amazon/aws/transfers/sql_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -15,11 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+import enum
from collections import namedtuple
-from enum import Enum
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
import numpy as np
import pandas as pd
@@ -27,25 +28,28 @@
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.hooks.dbapi import DbApiHook
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
if TYPE_CHECKING:
from airflow.utils.context import Context
-FILE_FORMAT = Enum(
- "FILE_FORMAT",
- "CSV, JSON, PARQUET",
-)
+class FILE_FORMAT(enum.Enum):
+ """Possible file formats."""
-FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function'])
+ CSV = enum.auto()
+ JSON = enum.auto()
+ PARQUET = enum.auto()
+
+
+FileOptions = namedtuple("FileOptions", ["mode", "suffix", "function"])
FILE_OPTIONS_MAP = {
- FILE_FORMAT.CSV: FileOptions('r+', '.csv', 'to_csv'),
- FILE_FORMAT.JSON: FileOptions('r+', '.json', 'to_json'),
- FILE_FORMAT.PARQUET: FileOptions('rb+', '.parquet', 'to_parquet'),
+ FILE_FORMAT.CSV: FileOptions("r+", ".csv", "to_csv"),
+ FILE_FORMAT.JSON: FileOptions("r+", ".json", "to_json"),
+ FILE_FORMAT.PARQUET: FileOptions("rb+", ".parquet", "to_parquet"),
}
@@ -79,11 +83,12 @@ class SqlToS3Operator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 's3_bucket',
- 's3_key',
- 'query',
+ "s3_bucket",
+ "s3_key",
+ "query",
+ "sql_conn_id",
)
- template_ext: Sequence[str] = ('.sql',)
+ template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {
"query": "sql",
"pd_kwargs": "json",
@@ -96,12 +101,12 @@ def __init__(
s3_bucket: str,
s3_key: str,
sql_conn_id: str,
- parameters: Union[None, Mapping, Iterable] = None,
+ parameters: None | Mapping | Iterable = None,
replace: bool = False,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- file_format: Literal['csv', 'json', 'parquet'] = 'csv',
- pd_kwargs: Optional[dict] = None,
+ aws_conn_id: str = "aws_default",
+ verify: bool | str | None = None,
+ file_format: Literal["csv", "json", "parquet"] = "csv",
+ pd_kwargs: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -116,17 +121,26 @@ def __init__(
self.parameters = parameters
if "path_or_buf" in self.pd_kwargs:
- raise AirflowException('The argument path_or_buf is not allowed, please remove it')
-
- self.file_format = getattr(FILE_FORMAT, file_format.upper(), None)
+ raise AirflowException("The argument path_or_buf is not allowed, please remove it")
- if self.file_format is None:
+ try:
+ self.file_format = FILE_FORMAT[file_format.upper()]
+ except KeyError:
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")
@staticmethod
- def _fix_int_dtypes(df: pd.DataFrame) -> None:
- """Mutate DataFrame to set dtypes for int columns containing NaN values."""
+ def _fix_dtypes(df: pd.DataFrame, file_format: FILE_FORMAT) -> None:
+ """
+ Mutate DataFrame to set dtypes for float columns containing NaN values.
+ Set dtype of object to str to allow for downstream transformations.
+ """
for col in df:
+
+ if df[col].dtype.name == "object" and file_format == FILE_FORMAT.PARQUET:
+ # if the type wasn't identified or converted, change it to a string so it can still be
+ # processed.
+ df[col] = df[col].astype(str)
+
if "float" in df[col].dtype.name and df[col].hasnans:
# inspect values to determine if dtype of non-null values is int or float
notna_series = df[col].dropna().values
@@ -139,13 +153,13 @@ def _fix_int_dtypes(df: pd.DataFrame) -> None:
df[col] = np.where(df[col].isnull(), None, df[col])
df[col] = df[col].astype(pd.Float64Dtype())
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
sql_hook = self._get_hook()
s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
self.log.info("Data from SQL obtained")
- self._fix_int_dtypes(data_df)
+ self._fix_dtypes(data_df, self.file_format)
file_options = FILE_OPTIONS_MAP[self.file_format]
with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
@@ -162,7 +176,7 @@ def _get_hook(self) -> DbApiHook:
self.log.debug("Get connection for %s", self.sql_conn_id)
conn = BaseHook.get_connection(self.sql_conn_id)
hook = conn.get_hook()
- if not callable(getattr(hook, 'get_pandas_df', None)):
+ if not callable(getattr(hook, "get_pandas_df", None)):
raise AirflowException(
"This hook is not supported. The hook class must have get_pandas_df method."
)
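
The name-based enum lookup used in `__init__` can be illustrated on its own. This is a minimal sketch with a hypothetical `FileFormat` enum and `parse_file_format` helper (neither is part of the provider); it raises `ValueError` instead of `AirflowException` so that it runs without Airflow installed. It also shows that members of a plain `enum.Enum` never compare equal to strings, so comparisons should be made against the member itself.

```python
import enum


class FileFormat(enum.Enum):
    """Hypothetical stand-in for the FILE_FORMAT enum in the diff."""

    CSV = enum.auto()
    JSON = enum.auto()
    PARQUET = enum.auto()


def parse_file_format(name: str) -> FileFormat:
    """Look a member up by name; unknown names raise a clear error."""
    try:
        return FileFormat[name.upper()]
    except KeyError:
        raise ValueError(f"Unsupported file format: {name!r}") from None


parsed = parse_file_format("parquet")
# A plain Enum member is not equal to its lowercase name as a string.
equals_string = parsed == "parquet"
equals_member = parsed == FileFormat.PARQUET
```

Subscript lookup only matches declared member names, whereas `getattr` on the enum class would also resolve unrelated class attributes.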
diff --git a/airflow/providers/amazon/aws/utils/__init__.py b/airflow/providers/amazon/aws/utils/__init__.py
index 13a83393a9124..0aff1289dc5e5 100644
--- a/airflow/providers/amazon/aws/utils/__init__.py
+++ b/airflow/providers/amazon/aws/utils/__init__.py
@@ -14,3 +14,33 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
+import re
+from datetime import datetime
+
+from airflow.version import version
+
+
+def trim_none_values(obj: dict):
+ return {key: val for key, val in obj.items() if val is not None}
+
+
+def datetime_to_epoch(date_time: datetime) -> int:
+ """Convert a datetime object to an epoch integer (seconds)."""
+ return int(date_time.timestamp())
+
+
+def datetime_to_epoch_ms(date_time: datetime) -> int:
+ """Convert a datetime object to an epoch integer (milliseconds)."""
+ return int(date_time.timestamp() * 1_000)
+
+
+def datetime_to_epoch_us(date_time: datetime) -> int:
+ """Convert a datetime object to an epoch integer (microseconds)."""
+ return int(date_time.timestamp() * 1_000_000)
+
+
+def get_airflow_version() -> tuple[int, ...]:
+ val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), version)
+ return tuple(int(x) for x in val.split("."))
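
The version parsing and epoch helpers added above can be exercised without Airflow. This is a minimal sketch, assuming the version string is passed in explicitly (the real helper reads `airflow.version.version`); `parse_version` and the sample inputs are hypothetical.

```python
from __future__ import annotations

import re
from datetime import datetime, timezone


def parse_version(version: str) -> tuple[int, ...]:
    """Keep only the leading MAJOR.MINOR.PATCH part and return it as integers."""
    val = re.sub(r"(\d+\.\d+\.\d+).*", lambda x: x.group(1), version)
    return tuple(int(x) for x in val.split("."))


def datetime_to_epoch_ms(date_time: datetime) -> int:
    """Convert a datetime object to an epoch integer (milliseconds)."""
    return int(date_time.timestamp() * 1_000)


dev_version = parse_version("2.5.0.dev0")
release_version = parse_version("2.4.3")
new_year_ms = datetime_to_epoch_ms(datetime(2022, 1, 1, tzinfo=timezone.utc))
```

Returning a tuple of integers allows comparisons such as `parse_version(v) >= (2, 3, 0)`, which plain string comparison would get wrong for versions like `2.10.0`.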
diff --git a/airflow/providers/amazon/aws/utils/connection_wrapper.py b/airflow/providers/amazon/aws/utils/connection_wrapper.py
new file mode 100644
index 0000000000000..f6ae3fdf01837
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/connection_wrapper.py
@@ -0,0 +1,475 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import warnings
+from copy import deepcopy
+from dataclasses import MISSING, InitVar, dataclass, field, fields
+from typing import TYPE_CHECKING, Any
+
+from botocore.config import Config
+
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException
+from airflow.providers.amazon.aws.utils import trim_none_values
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.log.secrets_masker import mask_secret
+from airflow.utils.types import NOTSET, ArgNotSet
+
+if TYPE_CHECKING:
+ from airflow.models.connection import Connection # Avoid circular imports.
+
+
+@dataclass
+class _ConnectionMetadata:
+ """Connection metadata data-class.
+
+ This class implements the main :class:`~airflow.models.connection.Connection` attributes
+ and is used in AwsConnectionWrapper to avoid circular imports.
+
+ For internal use only; this class might be changed or removed in the future.
+ """
+
+ conn_id: str | None = None
+ conn_type: str | None = None
+ description: str | None = None
+ host: str | None = None
+ login: str | None = None
+ password: str | None = None
+ schema: str | None = None
+ port: int | None = None
+ extra: str | dict | None = None
+
+ @property
+ def extra_dejson(self):
+ if not self.extra:
+ return {}
+ extra = deepcopy(self.extra)
+ if isinstance(extra, str):
+ try:
+ extra = json.loads(extra)
+ except json.JSONDecodeError as err:
+ raise AirflowException(
+ f"'extra' expected valid JSON-Object string. Original error:\n * {err}"
+ ) from None
+ if not isinstance(extra, dict):
+ raise TypeError(f"Expected JSON-Object or dict, got {type(extra).__name__}.")
+ return extra
+
+
+@dataclass
+class AwsConnectionWrapper(LoggingMixin):
+ """
+ AWS Connection Wrapper class helper.
+ Used to validate and resolve AWS Connection parameters.
+
+ ``conn`` is a reference to an Airflow Connection object or an AwsConnectionWrapper;
+ if it is set to ``None``, default values are used.
+
+ The precedence rules for ``region_name``
+ 1. Explicit set (in Hook) ``region_name``.
+ 2. Airflow Connection Extra 'region_name'.
+
+ The precedence rules for ``botocore_config``
+ 1. Explicit set (in Hook) ``botocore_config``.
+ 2. Construct from Airflow Connection Extra 'botocore_kwargs'.
+ 3. The wrapper's default value
+ """
+
+ conn: InitVar[Connection | AwsConnectionWrapper | _ConnectionMetadata | None]
+ region_name: str | None = field(default=None)
+ # boto3 client/resource configs
+ botocore_config: Config | None = field(default=None)
+ verify: bool | str | None = field(default=None)
+
+ # Reference to Airflow Connection attributes
+ # ``extra_config`` contains original Airflow Connection Extra.
+ conn_id: str | ArgNotSet | None = field(init=False, default=NOTSET)
+ conn_type: str | None = field(init=False, default=None)
+ login: str | None = field(init=False, repr=False, default=None)
+ password: str | None = field(init=False, repr=False, default=None)
+ extra_config: dict[str, Any] = field(init=False, repr=False, default_factory=dict)
+
+ # AWS Credentials from connection.
+ aws_access_key_id: str | None = field(init=False, default=None)
+ aws_secret_access_key: str | None = field(init=False, default=None)
+ aws_session_token: str | None = field(init=False, default=None)
+
+ # AWS Shared Credential profile_name
+ profile_name: str | None = field(init=False, default=None)
+ # Custom endpoint_url for boto3.client and boto3.resource
+ endpoint_url: str | None = field(init=False, default=None)
+
+ # Assume Role Configurations
+ role_arn: str | None = field(init=False, default=None)
+ assume_role_method: str | None = field(init=False, default=None)
+ assume_role_kwargs: dict[str, Any] = field(init=False, default_factory=dict)
+
+ @cached_property
+ def conn_repr(self):
+ return f"AWS Connection (conn_id={self.conn_id!r}, conn_type={self.conn_type!r})"
+
+ def __post_init__(self, conn: Connection):
+ if isinstance(conn, type(self)):
+ # For every field with init=False we copy the value from the original wrapper.
+ # For every field with init=True we keep the init value unless it equals the default.
+ # We can't use ``dataclasses.replace`` in a classmethod because
+ # InitVar arguments are not stored on the object,
+ # and we also do not want to run __post_init__ again, which would print all logs/warnings again.
+ for fl in fields(conn):
+ value = getattr(conn, fl.name)
+ if not fl.init:
+ setattr(self, fl.name, value)
+ else:
+ if fl.default is not MISSING:
+ default = fl.default
+ elif fl.default_factory is not MISSING:
+ default = fl.default_factory() # zero-argument callable
+ else:
+ continue # Value mandatory, skip
+
+ orig_value = getattr(self, fl.name)
+ if orig_value == default:
+ # Only take the original wrapper's value if the init value equals the default
+ setattr(self, fl.name, value)
+ return
+ elif not conn:
+ return
+
+ # Assign attributes from AWS Connection
+ self.conn_id = conn.conn_id
+ self.conn_type = conn.conn_type or "aws"
+ self.login = conn.login
+ self.password = conn.password
+ self.extra_config = deepcopy(conn.extra_dejson)
+
+ if self.conn_type.lower() == "s3":
+ warnings.warn(
+ f"{self.conn_repr} has connection type 's3', "
+ "which has been replaced by connection type 'aws'. "
+ "Please update your connection to have `conn_type='aws'`.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ elif self.conn_type != "aws":
+ warnings.warn(
+ f"{self.conn_repr} expected connection type 'aws', got {self.conn_type!r}. "
+ "This connection might not work correctly. "
+ "Please use Amazon Web Services Connection type.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ extra = deepcopy(conn.extra_dejson)
+ session_kwargs = extra.get("session_kwargs", {})
+ if session_kwargs:
+ warnings.warn(
+ "'session_kwargs' in extra config is deprecated and will be removed in a future release. "
+ f"Please specify arguments passed to boto3 Session directly in {self.conn_repr} extra.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ # Retrieve initial connection credentials
+ init_credentials = self._get_credentials(**extra)
+ self.aws_access_key_id, self.aws_secret_access_key, self.aws_session_token = init_credentials
+
+ if not self.region_name:
+ if "region_name" in extra:
+ self.region_name = extra["region_name"]
+ self.log.debug("Retrieving region_name=%s from %s extra.", self.region_name, self.conn_repr)
+ elif "region_name" in session_kwargs:
+ self.region_name = session_kwargs["region_name"]
+ self.log.debug(
+ "Retrieving region_name=%s from %s extra['session_kwargs'].",
+ self.region_name,
+ self.conn_repr,
+ )
+
+ if self.verify is None and "verify" in extra:
+ self.verify = extra["verify"]
+ self.log.debug("Retrieving verify=%s from %s extra.", self.verify, self.conn_repr)
+
+ if "profile_name" in extra:
+ self.profile_name = extra["profile_name"]
+ self.log.debug("Retrieving profile_name=%s from %s extra.", self.profile_name, self.conn_repr)
+ elif "profile_name" in session_kwargs:
+ self.profile_name = session_kwargs["profile_name"]
+ self.log.debug(
+ "Retrieving profile_name=%s from %s extra['session_kwargs'].",
+ self.profile_name,
+ self.conn_repr,
+ )
+
+ # Warn the user about an invalid parameter which is actually not related to 'profile_name'.
+ # ToDo: Remove this check entirely as soon as support for credentials from s3_config_file is dropped
+ if "profile" in extra and "s3_config_file" not in extra and not self.profile_name:
+ warnings.warn(
+ f"Found 'profile' without specifying 's3_config_file' in {self.conn_repr} extra. "
+ "If a profile from AWS Shared Credentials is required, please "
+ f"set 'profile_name' in {self.conn_repr} extra.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ config_kwargs = extra.get("config_kwargs")
+ if not self.botocore_config and config_kwargs:
+ # https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+ self.log.debug("Retrieving botocore config=%s from %s extra.", config_kwargs, self.conn_repr)
+ self.botocore_config = Config(**config_kwargs)
+
+ if conn.host:
+ warnings.warn(
+ f"Host {conn.host} specified in the connection is not used."
+ " Please, set it on extra['endpoint_url'] instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
+ self.endpoint_url = extra.get("host")
+ if self.endpoint_url:
+ warnings.warn(
+ "extra['host'] is deprecated and will be removed in a future release."
+ " Please set extra['endpoint_url'] instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ else:
+ self.endpoint_url = extra.get("endpoint_url")
+
+ # Retrieve Assume Role Configuration
+ assume_role_configs = self._get_assume_role_configs(**extra)
+ self.role_arn, self.assume_role_method, self.assume_role_kwargs = assume_role_configs
+
+ @classmethod
+ def from_connection_metadata(
+ cls,
+ conn_id: str | None = None,
+ login: str | None = None,
+ password: str | None = None,
+ extra: dict[str, Any] | None = None,
+ ):
+ """
+ Create config from connection metadata.
+
+ :param conn_id: Custom connection ID.
+ :param login: AWS Access Key ID.
+ :param password: AWS Secret Access Key.
+ :param extra: Connection Extra metadata.
+ """
+ conn_meta = _ConnectionMetadata(
+ conn_id=conn_id, conn_type="aws", login=login, password=password, extra=extra
+ )
+ return cls(conn=conn_meta)
+
+ @property
+ def extra_dejson(self):
+ """Compatibility with `airflow.models.Connection.extra_dejson` property."""
+ return self.extra_config
+
+ @property
+ def session_kwargs(self) -> dict[str, Any]:
+ """Additional kwargs passed to boto3.session.Session."""
+ return trim_none_values(
+ {
+ "aws_access_key_id": self.aws_access_key_id,
+ "aws_secret_access_key": self.aws_secret_access_key,
+ "aws_session_token": self.aws_session_token,
+ "region_name": self.region_name,
+ "profile_name": self.profile_name,
+ }
+ )
+
+ def __bool__(self):
+ return self.conn_id is not NOTSET
+
+ def _get_credentials(
+ self,
+ *,
+ aws_access_key_id: str | None = None,
+ aws_secret_access_key: str | None = None,
+ aws_session_token: str | None = None,
+ # Deprecated Values
+ s3_config_file: str | None = None,
+ s3_config_format: str | None = None,
+ profile: str | None = None,
+ session_kwargs: dict[str, Any] | None = None,
+ **kwargs,
+ ) -> tuple[str | None, str | None, str | None]:
+ """
+ Get AWS credentials from connection login/password and extra.
+
+ ``aws_access_key_id`` and ``aws_secret_access_key`` order
+ 1. From Connection login and password
+ 2. From Connection extra['aws_access_key_id'] and extra['aws_secret_access_key']
+ 3. (deprecated) From Connection extra['session_kwargs']
+ 4. (deprecated) From local credentials file
+
+ Get ``aws_session_token`` from extra['aws_session_token'] or (deprecated) extra['session_kwargs']
+
+ """
+ session_kwargs = session_kwargs or {}
+ session_aws_access_key_id = session_kwargs.get("aws_access_key_id")
+ session_aws_secret_access_key = session_kwargs.get("aws_secret_access_key")
+ session_aws_session_token = session_kwargs.get("aws_session_token")
+
+ if self.login and self.password:
+ self.log.info("%s credentials retrieved from login and password.", self.conn_repr)
+ aws_access_key_id, aws_secret_access_key = self.login, self.password
+ elif aws_access_key_id and aws_secret_access_key:
+ self.log.info("%s credentials retrieved from extra.", self.conn_repr)
+ elif session_aws_access_key_id and session_aws_secret_access_key:
+ aws_access_key_id = session_aws_access_key_id
+ aws_secret_access_key = session_aws_secret_access_key
+ self.log.info("%s credentials retrieved from extra['session_kwargs'].", self.conn_repr)
+ elif s3_config_file:
+ aws_access_key_id, aws_secret_access_key = _parse_s3_config(
+ s3_config_file,
+ s3_config_format,
+ profile,
+ )
+ self.log.info("%s credentials retrieved from extra['s3_config_file']", self.conn_repr)
+
+ if aws_session_token:
+ self.log.info(
+ "%s session token retrieved from extra, please note you are responsible for renewing these.",
+ self.conn_repr,
+ )
+ elif session_aws_session_token:
+ aws_session_token = session_aws_session_token
+ self.log.info(
+ "%s session token retrieved from extra['session_kwargs'], "
+ "please note you are responsible for renewing these.",
+ self.conn_repr,
+ )
+
+ return aws_access_key_id, aws_secret_access_key, aws_session_token
+
+ def _get_assume_role_configs(
+ self,
+ *,
+ role_arn: str | None = None,
+ assume_role_method: str = "assume_role",
+ assume_role_kwargs: dict[str, Any] | None = None,
+ # Deprecated Values
+ aws_account_id: str | None = None,
+ aws_iam_role: str | None = None,
+ external_id: str | None = None,
+ **kwargs,
+ ) -> tuple[str | None, str | None, dict[str, Any]]:
+ """Get assume role configs from Connection extra."""
+ if role_arn:
+ self.log.debug("Retrieving role_arn=%r from %s extra.", role_arn, self.conn_repr)
+ elif aws_account_id and aws_iam_role:
+ warnings.warn(
+ "Constructing 'role_arn' from extra['aws_account_id'] and extra['aws_iam_role'] is deprecated"
+ " and will be removed in a future release."
+ f" Please set 'role_arn' in {self.conn_repr} extra.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+ role_arn = f"arn:aws:iam::{aws_account_id}:role/{aws_iam_role}"
+ self.log.debug(
+ "Constructing role_arn=%r from %s extra['aws_account_id'] and extra['aws_iam_role'].",
+ role_arn,
+ self.conn_repr,
+ )
+
+ if not role_arn:
+ # There is no reason to obtain `assume_role_method` and `assume_role_kwargs` if `role_arn` is not set.
+ return None, None, {}
+
+ supported_methods = ["assume_role", "assume_role_with_saml", "assume_role_with_web_identity"]
+ if assume_role_method not in supported_methods:
+ raise NotImplementedError(
+ f"Found assume_role_method={assume_role_method!r} in {self.conn_repr} extra."
+ f" Currently {supported_methods} are supported."
+ ' (Excluding this setting will default to "assume_role").'
+ )
+ self.log.debug("Retrieve assume_role_method=%r from %s.", assume_role_method, self.conn_repr)
+
+ assume_role_kwargs = assume_role_kwargs or {}
+ if "ExternalId" not in assume_role_kwargs and external_id:
+ warnings.warn(
+ "'external_id' in extra config is deprecated and will be removed in a future release. "
+ f"Please set 'ExternalId' in 'assume_role_kwargs' in {self.conn_repr} extra.",
+ DeprecationWarning,
+ stacklevel=3,
+ )
+ assume_role_kwargs["ExternalId"] = external_id
+
+ return role_arn, assume_role_method, assume_role_kwargs
+
+
+def _parse_s3_config(
+ config_file_name: str, config_format: str | None = "boto", profile: str | None = None
+) -> tuple[str | None, str | None]:
+ """
+ Parses a config file for s3 credentials. Can currently
+ parse boto, s3cmd.conf and AWS SDK config formats
+
+ :param config_file_name: path to the config file
+ :param config_format: config type. One of "boto", "s3cmd" or "aws".
+ Defaults to "boto"
+ :param profile: profile name in AWS type config file
+ """
+ warnings.warn(
+ "Using a local credentials file was never documented or well tested. "
+ "Obtaining credentials this way is deprecated and will be removed in a future release.",
+ DeprecationWarning,
+ stacklevel=4,
+ )
+
+ import configparser
+
+ config = configparser.ConfigParser()
+ if config.read(config_file_name): # pragma: no cover
+ sections = config.sections()
+ else:
+ raise AirflowException(f"Couldn't read {config_file_name}")
+ # Setting option names depending on file format
+ if config_format is None:
+ config_format = "boto"
+ conf_format = config_format.lower()
+ if conf_format == "boto": # pragma: no cover
+ if profile is not None and "profile " + profile in sections:
+ cred_section = "profile " + profile
+ else:
+ cred_section = "Credentials"
+ elif conf_format == "aws" and profile is not None:
+ cred_section = profile
+ else:
+ cred_section = "default"
+ # Option names
+ if conf_format in ("boto", "aws"): # pragma: no cover
+ key_id_option = "aws_access_key_id"
+ secret_key_option = "aws_secret_access_key"
+ else:
+ key_id_option = "access_key"
+ secret_key_option = "secret_key"
+ # Actual Parsing
+ if cred_section not in sections:
+ raise AirflowException("This config file format is not recognized")
+ else:
+ try:
+ access_key = config.get(cred_section, key_id_option)
+ secret_key = config.get(cred_section, secret_key_option)
+ mask_secret(secret_key)
+ except Exception:
+ raise AirflowException("Option Error in parsing s3 config file")
+ return access_key, secret_key
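
The precedence described in the `_get_credentials` docstring can be sketched as a plain function. This is a simplified illustration under stated assumptions, not the wrapper's implementation: `resolve_access_keys` is a hypothetical helper, the deprecated `s3_config_file` branch and the session token handling are omitted, and unlike the wrapper it returns `(None, None)` when no complete key pair is found.

```python
from __future__ import annotations

from typing import Any


def resolve_access_keys(
    login: str | None, password: str | None, extra: dict[str, Any]
) -> tuple[str | None, str | None]:
    """Return the first complete key pair, checked in precedence order."""
    session_kwargs = extra.get("session_kwargs") or {}
    candidates = [
        # 1. Connection login and password
        (login, password),
        # 2. Top-level keys in the connection extra
        (extra.get("aws_access_key_id"), extra.get("aws_secret_access_key")),
        # 3. (deprecated) keys nested under extra["session_kwargs"]
        (session_kwargs.get("aws_access_key_id"), session_kwargs.get("aws_secret_access_key")),
    ]
    for key_id, secret in candidates:
        if key_id and secret:
            return key_id, secret
    return None, None


# Placeholder values, used only for illustration.
extra = {"aws_access_key_id": "extra-key", "aws_secret_access_key": "extra-secret"}
from_login = resolve_access_keys("login-key", "login-secret", extra)
from_extra = resolve_access_keys(None, None, extra)
from_nothing = resolve_access_keys(None, None, {})
```

When nothing is configured on the connection, no keys are passed to `boto3.session.Session`, so boto3 falls back to its default credential chain.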
diff --git a/airflow/providers/amazon/aws/utils/eks_get_token.py b/airflow/providers/amazon/aws/utils/eks_get_token.py
index d9422b35e301c..27bf51676f8b1 100644
--- a/airflow/providers/amazon/aws/utils/eks_get_token.py
+++ b/airflow/providers/amazon/aws/utils/eks_get_token.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import argparse
import json
@@ -29,23 +30,23 @@
def get_expiration_time():
token_expiration = datetime.now(timezone.utc) + timedelta(minutes=TOKEN_EXPIRATION_MINUTES)
- return token_expiration.strftime('%Y-%m-%dT%H:%M:%SZ')
+ return token_expiration.strftime("%Y-%m-%dT%H:%M:%SZ")
def get_parser():
- parser = argparse.ArgumentParser(description='Get a token for authentication with an Amazon EKS cluster.')
+ parser = argparse.ArgumentParser(description="Get a token for authentication with an Amazon EKS cluster.")
parser.add_argument(
- '--cluster-name', help='The name of the cluster to generate kubeconfig file for.', required=True
+ "--cluster-name", help="The name of the cluster to generate kubeconfig file for.", required=True
)
parser.add_argument(
- '--aws-conn-id',
+ "--aws-conn-id",
help=(
- 'The Airflow connection used for AWS credentials. '
- 'If not specified or empty then the default boto3 behaviour is used.'
+ "The Airflow connection used for AWS credentials. "
+ "If not specified or empty then the default boto3 behaviour is used."
),
)
parser.add_argument(
- '--region-name', help='AWS region_name. If not specified then the default boto3 behaviour is used.'
+ "--region-name", help="AWS region_name. If not specified then the default boto3 behaviour is used."
)
return parser
@@ -66,5 +67,5 @@ def main():
print(json.dumps(exec_credential_object))
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/airflow/providers/amazon/aws/utils/emailer.py b/airflow/providers/amazon/aws/utils/emailer.py
index 7f0356a3a8d59..3e00abc78a08c 100644
--- a/airflow/providers/amazon/aws/utils/emailer.py
+++ b/airflow/providers/amazon/aws/utils/emailer.py
@@ -16,23 +16,25 @@
# specific language governing permissions and limitations
# under the License.
"""Airflow module for email backend using AWS SES"""
-from typing import Any, Dict, List, Optional, Union
+from __future__ import annotations
+
+from typing import Any
from airflow.providers.amazon.aws.hooks.ses import SesHook
def send_email(
- to: Union[List[str], str],
+ to: list[str] | str,
subject: str,
html_content: str,
- files: Optional[List] = None,
- cc: Optional[Union[List[str], str]] = None,
- bcc: Optional[Union[List[str], str]] = None,
- mime_subtype: str = 'mixed',
- mime_charset: str = 'utf-8',
- conn_id: str = 'aws_default',
- from_email: Optional[str] = None,
- custom_headers: Optional[Dict[str, Any]] = None,
+ files: list | None = None,
+ cc: list[str] | str | None = None,
+ bcc: list[str] | str | None = None,
+ mime_subtype: str = "mixed",
+ mime_charset: str = "utf-8",
+ conn_id: str = "aws_default",
+ from_email: str | None = None,
+ custom_headers: dict[str, Any] | None = None,
**kwargs,
) -> None:
"""Email backend for SES."""
diff --git a/airflow/providers/amazon/aws/utils/rds.py b/airflow/providers/amazon/aws/utils/rds.py
index 154f65b5560c1..873f2cf83ecf0 100644
--- a/airflow/providers/amazon/aws/utils/rds.py
+++ b/airflow/providers/amazon/aws/utils/rds.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
from enum import Enum
diff --git a/airflow/providers/amazon/aws/utils/redshift.py b/airflow/providers/amazon/aws/utils/redshift.py
index bb64c9b46f463..d931cb047430d 100644
--- a/airflow/providers/amazon/aws/utils/redshift.py
+++ b/airflow/providers/amazon/aws/utils/redshift.py
@@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import logging
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 413b6dcfed7ca..59277d469d37a 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -22,6 +22,12 @@ description: |
Amazon integration (including `Amazon Web Services (AWS) `__).
versions:
+ - 6.1.0
+ - 6.0.0
+ - 5.1.0
+ - 5.0.0
+ - 4.1.0
+ - 4.0.0
- 3.4.0
- 3.3.0
- 3.2.0
@@ -40,8 +46,22 @@ versions:
- 1.1.0
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-common-sql>=1.3.1
+ - boto3>=1.24.0
+ # watchtower 3 was released at the end of January 2022 and introduced breaking changes across the
+ # board that might change logging behaviour:
+ # https://github.com/kislyuk/watchtower/blob/develop/Changes.rst#changes-for-v300-2022-01-26
+ # TODO: update to watchtower >3
+ - watchtower~=2.0.1
+ - jsonpath_ng>=1.5.3
+ - redshift_connector>=2.0.888
+ - sqlalchemy_redshift>=0.8.6
+ - pandas>=0.17.1
+ - mypy-boto3-rds>=1.24.0
+ - mypy-boto3-redshift-data>=1.24.0
+ - mypy-boto3-appflow>=1.24.0
integrations:
- integration-name: Amazon Athena
@@ -101,6 +121,12 @@ integrations:
- /docs/apache-airflow-providers-amazon/operators/emr_eks.rst
logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png
tags: [aws]
+ - integration-name: Amazon EMR Serverless
+ external-doc-url: https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/emr-serverless.html
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/emr_serverless.rst
+ logo: /integration-logos/aws/Amazon-EMR_light-bg@4x.png
+ tags: [aws]
- integration-name: Amazon Glacier
external-doc-url: https://aws.amazon.com/glacier/
logo: /integration-logos/aws/Amazon-S3-Glacier_light-bg@4x.png
@@ -212,6 +238,12 @@ integrations:
external-doc-url: https://docs.aws.amazon.com/STS/latest/APIReference/welcome.html
logo: /integration-logos/aws/AWS-STS_light-bg@4x.png
tags: [aws]
+ - integration-name: Amazon Appflow
+ external-doc-url: https://docs.aws.amazon.com/appflow/1.0/APIReference/Welcome.html
+ logo: /integration-logos/aws/Amazon_AppFlow_light.png
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/appflow.rst
+ tags: [aws]
operators:
- integration-name: Amazon Athena
@@ -229,15 +261,8 @@ operators:
- integration-name: AWS Database Migration Service
python-modules:
- airflow.providers.amazon.aws.operators.dms
- - airflow.providers.amazon.aws.operators.dms_create_task
- - airflow.providers.amazon.aws.operators.dms_delete_task
- - airflow.providers.amazon.aws.operators.dms_describe_tasks
- - airflow.providers.amazon.aws.operators.dms_start_task
- - airflow.providers.amazon.aws.operators.dms_stop_task
- integration-name: Amazon EC2
python-modules:
- - airflow.providers.amazon.aws.operators.ec2_start_instance
- - airflow.providers.amazon.aws.operators.ec2_stop_instance
- airflow.providers.amazon.aws.operators.ec2
- integration-name: Amazon ECS
python-modules:
@@ -248,14 +273,9 @@ operators:
- integration-name: Amazon EMR
python-modules:
- airflow.providers.amazon.aws.operators.emr
- - airflow.providers.amazon.aws.operators.emr_add_steps
- - airflow.providers.amazon.aws.operators.emr_create_job_flow
- - airflow.providers.amazon.aws.operators.emr_modify_cluster
- - airflow.providers.amazon.aws.operators.emr_terminate_job_flow
- integration-name: Amazon EMR on EKS
python-modules:
- airflow.providers.amazon.aws.operators.emr
- - airflow.providers.amazon.aws.operators.emr_containers
- integration-name: Amazon Glacier
python-modules:
- airflow.providers.amazon.aws.operators.glacier
@@ -266,27 +286,13 @@ operators:
- integration-name: AWS Lambda
python-modules:
- airflow.providers.amazon.aws.operators.aws_lambda
+ - airflow.providers.amazon.aws.operators.lambda_function
- integration-name: Amazon Simple Storage Service (S3)
python-modules:
- - airflow.providers.amazon.aws.operators.s3_bucket
- - airflow.providers.amazon.aws.operators.s3_bucket_tagging
- - airflow.providers.amazon.aws.operators.s3_copy_object
- - airflow.providers.amazon.aws.operators.s3_delete_objects
- - airflow.providers.amazon.aws.operators.s3_file_transform
- - airflow.providers.amazon.aws.operators.s3_list
- - airflow.providers.amazon.aws.operators.s3_list_prefixes
- airflow.providers.amazon.aws.operators.s3
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.operators.sagemaker
- - airflow.providers.amazon.aws.operators.sagemaker_base
- - airflow.providers.amazon.aws.operators.sagemaker_endpoint
- - airflow.providers.amazon.aws.operators.sagemaker_endpoint_config
- - airflow.providers.amazon.aws.operators.sagemaker_model
- - airflow.providers.amazon.aws.operators.sagemaker_processing
- - airflow.providers.amazon.aws.operators.sagemaker_training
- - airflow.providers.amazon.aws.operators.sagemaker_transform
- - airflow.providers.amazon.aws.operators.sagemaker_tuning
- integration-name: Amazon Simple Notification Service (SNS)
python-modules:
- airflow.providers.amazon.aws.operators.sns
@@ -295,21 +301,21 @@ operators:
- airflow.providers.amazon.aws.operators.sqs
- integration-name: AWS Step Functions
python-modules:
- - airflow.providers.amazon.aws.operators.step_function_get_execution_output
- - airflow.providers.amazon.aws.operators.step_function_start_execution
- airflow.providers.amazon.aws.operators.step_function
- integration-name: Amazon RDS
python-modules:
- airflow.providers.amazon.aws.operators.rds
- integration-name: Amazon Redshift
python-modules:
- - airflow.providers.amazon.aws.operators.redshift
- airflow.providers.amazon.aws.operators.redshift_sql
- airflow.providers.amazon.aws.operators.redshift_cluster
- airflow.providers.amazon.aws.operators.redshift_data
- integration-name: Amazon QuickSight
python-modules:
- airflow.providers.amazon.aws.operators.quicksight
+ - integration-name: Amazon Appflow
+ python-modules:
+ - airflow.providers.amazon.aws.operators.appflow
sensors:
- integration-name: Amazon Athena
@@ -323,25 +329,22 @@ sensors:
- airflow.providers.amazon.aws.sensors.cloud_formation
- integration-name: AWS Database Migration Service
python-modules:
- - airflow.providers.amazon.aws.sensors.dms_task
- airflow.providers.amazon.aws.sensors.dms
- integration-name: Amazon EC2
python-modules:
- - airflow.providers.amazon.aws.sensors.ec2_instance_state
- airflow.providers.amazon.aws.sensors.ec2
+ - integration-name: Amazon ECS
+ python-modules:
+ - airflow.providers.amazon.aws.sensors.ecs
- integration-name: Amazon Elastic Kubernetes Service (EKS)
python-modules:
- airflow.providers.amazon.aws.sensors.eks
- integration-name: Amazon EMR
python-modules:
- airflow.providers.amazon.aws.sensors.emr
- - airflow.providers.amazon.aws.sensors.emr_base
- - airflow.providers.amazon.aws.sensors.emr_job_flow
- - airflow.providers.amazon.aws.sensors.emr_step
- integration-name: Amazon EMR on EKS
python-modules:
- airflow.providers.amazon.aws.sensors.emr
- - airflow.providers.amazon.aws.sensors.emr_containers
- integration-name: Amazon Glacier
python-modules:
- airflow.providers.amazon.aws.sensors.glacier
@@ -355,28 +358,18 @@ sensors:
- airflow.providers.amazon.aws.sensors.rds
- integration-name: Amazon Redshift
python-modules:
- - airflow.providers.amazon.aws.sensors.redshift
- airflow.providers.amazon.aws.sensors.redshift_cluster
- integration-name: Amazon Simple Storage Service (S3)
python-modules:
- - airflow.providers.amazon.aws.sensors.s3_key
- - airflow.providers.amazon.aws.sensors.s3_keys_unchanged
- - airflow.providers.amazon.aws.sensors.s3_prefix
- airflow.providers.amazon.aws.sensors.s3
- integration-name: Amazon SageMaker
python-modules:
- airflow.providers.amazon.aws.sensors.sagemaker
- - airflow.providers.amazon.aws.sensors.sagemaker_base
- - airflow.providers.amazon.aws.sensors.sagemaker_endpoint
- - airflow.providers.amazon.aws.sensors.sagemaker_training
- - airflow.providers.amazon.aws.sensors.sagemaker_transform
- - airflow.providers.amazon.aws.sensors.sagemaker_tuning
- integration-name: Amazon Simple Queue Service (SQS)
python-modules:
- airflow.providers.amazon.aws.sensors.sqs
- integration-name: AWS Step Functions
python-modules:
- - airflow.providers.amazon.aws.sensors.step_function_execution
- airflow.providers.amazon.aws.sensors.step_function
- integration-name: Amazon QuickSight
python-modules:
@@ -389,7 +382,6 @@ hooks:
- integration-name: Amazon DynamoDB
python-modules:
- airflow.providers.amazon.aws.hooks.dynamodb
- - airflow.providers.amazon.aws.hooks.aws_dynamodb
- integration-name: Amazon Web Services
python-modules:
- airflow.providers.amazon.aws.hooks.base_aws
@@ -409,6 +401,9 @@ hooks:
- integration-name: Amazon EC2
python-modules:
- airflow.providers.amazon.aws.hooks.ec2
+ - integration-name: Amazon ECS
+ python-modules:
+ - airflow.providers.amazon.aws.hooks.ecs
- integration-name: Amazon ElastiCache
python-modules:
- airflow.providers.amazon.aws.hooks.elasticache_replication_group
@@ -420,7 +415,7 @@ hooks:
- airflow.providers.amazon.aws.hooks.emr
- integration-name: Amazon EMR on EKS
python-modules:
- - airflow.providers.amazon.aws.hooks.emr_containers
+ - airflow.providers.amazon.aws.hooks.emr
- integration-name: Amazon Glacier
python-modules:
- airflow.providers.amazon.aws.hooks.glacier
@@ -443,7 +438,6 @@ hooks:
- airflow.providers.amazon.aws.hooks.rds
- integration-name: Amazon Redshift
python-modules:
- - airflow.providers.amazon.aws.hooks.redshift
- airflow.providers.amazon.aws.hooks.redshift_sql
- airflow.providers.amazon.aws.hooks.redshift_cluster
- airflow.providers.amazon.aws.hooks.redshift_data
@@ -474,6 +468,9 @@ hooks:
- integration-name: AWS Security Token Service (STS)
python-modules:
- airflow.providers.amazon.aws.hooks.sts
+ - integration-name: Amazon Appflow
+ python-modules:
+ - airflow.providers.amazon.aws.hooks.appflow
transfers:
- source-integration-name: Amazon DynamoDB
@@ -504,10 +501,6 @@ transfers:
target-integration-name: Amazon Simple Storage Service (S3)
how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/mongo_to_s3.rst
python-module: airflow.providers.amazon.aws.transfers.mongo_to_s3
- - source-integration-name: MySQL
- target-integration-name: Amazon Simple Storage Service (S3)
- how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/sql_to_s3.rst
- python-module: airflow.providers.amazon.aws.transfers.mysql_to_s3
- source-integration-name: Amazon Redshift
target-integration-name: Amazon Simple Storage Service (S3)
how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/redshift_to_s3.rst
@@ -543,32 +536,26 @@ transfers:
target-integration-name: Amazon Simple Storage Service (S3)
how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/local_to_s3.rst
python-module: airflow.providers.amazon.aws.transfers.local_to_s3
- - source-integration-name: SQL
+ - source-integration-name: Common SQL
target-integration-name: Amazon Simple Storage Service (S3)
how-to-guide: /docs/apache-airflow-providers-amazon/operators/transfer/sql_to_s3.rst
python-module: airflow.providers.amazon.aws.transfers.sql_to_s3
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.amazon.aws.hooks.s3.S3Hook
- - airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook
- - airflow.providers.amazon.aws.hooks.emr.EmrHook
- - airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook
extra-links:
- - airflow.providers.amazon.aws.operators.emr.EmrClusterLink
- - airflow.providers.amazon.aws.operators.emr_create_job_flow.EmrClusterLink
+ - airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink
+ - airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
+ - airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
+ - airflow.providers.amazon.aws.links.emr.EmrClusterLink
+ - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink
connection-types:
- - hook-class-name: airflow.providers.amazon.aws.hooks.s3.S3Hook
- connection-type: s3
- - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook
+ - hook-class-name: airflow.providers.amazon.aws.hooks.base_aws.AwsGenericHook
connection-type: aws
- hook-class-name: airflow.providers.amazon.aws.hooks.emr.EmrHook
connection-type: emr
- hook-class-name: airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook
connection-type: redshift
- - hook-class-name: airflow.providers.amazon.aws.hooks.redshift.RedshiftDataHook
- connection-type: aws
secrets-backends:
- airflow.providers.amazon.aws.secrets.secrets_manager.SecretsManagerBackend
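Note on the module removals above: the `provider.yaml` changes drop many per-operator modules (for example `operators.s3_bucket`, `sensors.s3_key`, `hooks.aws_dynamodb`) while keeping one consolidated module per integration. A rough sketch of how DAG imports could be rewritten follows. It assumes each removed module was folded into the consolidated module listed under the same integration and that class names are unchanged; the mapping is a small hand-picked subset, and `rewrite_import` is a hypothetical helper, not an Airflow API.

```python
import re

PREFIX = "airflow.providers.amazon.aws."

# Assumption: removed module -> consolidated module that remains listed under
# the same integration in provider.yaml. Only a few entries are shown.
CONSOLIDATED = {
    "operators.s3_bucket": "operators.s3",
    "operators.s3_copy_object": "operators.s3",
    "operators.emr_add_steps": "operators.emr",
    "operators.sagemaker_training": "operators.sagemaker",
    "sensors.s3_key": "sensors.s3",
    "hooks.aws_dynamodb": "hooks.dynamodb",
}

IMPORT_RE = re.compile(r"^(\s*from\s+)(\S+)(\s+import\s+.+)$")


def rewrite_import(line: str) -> str:
    """Rewrite a 'from <module> import ...' line if its module is in the mapping."""
    match = IMPORT_RE.match(line)
    if not match:
        return line
    head, module, tail = match.groups()
    if not module.startswith(PREFIX):
        return line
    new_module = CONSOLIDATED.get(module[len(PREFIX):])
    if new_module is None:
        return line
    return f"{head}{PREFIX}{new_module}{tail}"


old = "from airflow.providers.amazon.aws.operators.s3_bucket import S3CreateBucketOperator"
print(rewrite_import(old))
```

Lines that do not match the pattern, or whose module is not in the mapping, are returned unchanged, so the helper can be run over a whole file line by line.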
diff --git a/airflow/providers/apache/beam/.latest-doc-only-change.txt b/airflow/providers/apache/beam/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/apache/beam/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/beam/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/beam/CHANGELOG.rst b/airflow/providers/apache/beam/CHANGELOG.rst
index 8969c402ce18f..cfdad5d2fa201 100644
--- a/airflow/providers/apache/beam/CHANGELOG.rst
+++ b/airflow/providers/apache/beam/CHANGELOG.rst
@@ -16,9 +16,67 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by the release manager.
+
Changelog
---------
+4.1.0
+.....
+
+This release of the provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Features
+~~~~~~~~
+
+* ``Add backward compatibility with old versions of Apache Beam (#27263)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+
+4.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of the provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Features
+~~~~~~~~
+
+* ``Added missing project_id to the wait_for_job (#24020)``
+* ``Support impersonation service account parameter for Dataflow runner (#23961)``
+
+Misc
+~~~~
+
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``AIP-47 - Migrate beam DAGs to new design #22439 (#24211)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
3.4.0
.....
diff --git a/airflow/providers/apache/beam/example_dags/example_beam.py b/airflow/providers/apache/beam/example_dags/example_beam.py
deleted file mode 100644
index ea52458303129..0000000000000
--- a/airflow/providers/apache/beam/example_dags/example_beam.py
+++ /dev/null
@@ -1,437 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Example Airflow DAG for Apache Beam operators
-"""
-import os
-from datetime import datetime
-from urllib.parse import urlparse
-
-from airflow import models
-from airflow.providers.apache.beam.operators.beam import (
- BeamRunGoPipelineOperator,
- BeamRunJavaPipelineOperator,
- BeamRunPythonPipelineOperator,
-)
-from airflow.providers.google.cloud.hooks.dataflow import DataflowJobStatus
-from airflow.providers.google.cloud.operators.dataflow import DataflowConfiguration
-from airflow.providers.google.cloud.sensors.dataflow import DataflowJobStatusSensor
-from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
-from airflow.utils.trigger_rule import TriggerRule
-
-GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
-GCS_INPUT = os.environ.get('APACHE_BEAM_PYTHON', 'gs://INVALID BUCKET NAME/shakespeare/kinglear.txt')
-GCS_TMP = os.environ.get('APACHE_BEAM_GCS_TMP', 'gs://INVALID BUCKET NAME/temp/')
-GCS_STAGING = os.environ.get('APACHE_BEAM_GCS_STAGING', 'gs://INVALID BUCKET NAME/staging/')
-GCS_OUTPUT = os.environ.get('APACHE_BEAM_GCS_OUTPUT', 'gs://INVALID BUCKET NAME/output')
-GCS_PYTHON = os.environ.get('APACHE_BEAM_PYTHON', 'gs://INVALID BUCKET NAME/wordcount_debugging.py')
-GCS_PYTHON_DATAFLOW_ASYNC = os.environ.get(
- 'APACHE_BEAM_PYTHON_DATAFLOW_ASYNC', 'gs://INVALID BUCKET NAME/wordcount_debugging.py'
-)
-GCS_GO = os.environ.get('APACHE_BEAM_GO', 'gs://INVALID BUCKET NAME/wordcount_debugging.go')
-GCS_GO_DATAFLOW_ASYNC = os.environ.get(
- 'APACHE_BEAM_GO_DATAFLOW_ASYNC', 'gs://INVALID BUCKET NAME/wordcount_debugging.go'
-)
-GCS_JAR_DIRECT_RUNNER = os.environ.get(
- 'APACHE_BEAM_DIRECT_RUNNER_JAR',
- 'gs://INVALID BUCKET NAME/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-DirectRunner.jar',
-)
-GCS_JAR_DATAFLOW_RUNNER = os.environ.get(
- 'APACHE_BEAM_DATAFLOW_RUNNER_JAR', 'gs://INVALID BUCKET NAME/word-count-beam-bundled-0.1.jar'
-)
-GCS_JAR_SPARK_RUNNER = os.environ.get(
- 'APACHE_BEAM_SPARK_RUNNER_JAR',
- 'gs://INVALID BUCKET NAME/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-SparkRunner.jar',
-)
-GCS_JAR_FLINK_RUNNER = os.environ.get(
- 'APACHE_BEAM_FLINK_RUNNER_JAR',
- 'gs://INVALID BUCKET NAME/tests/dataflow-templates-bundled-java=11-beam-v2.25.0-FlinkRunner.jar',
-)
-
-GCS_JAR_DIRECT_RUNNER_PARTS = urlparse(GCS_JAR_DIRECT_RUNNER)
-GCS_JAR_DIRECT_RUNNER_BUCKET_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.netloc
-GCS_JAR_DIRECT_RUNNER_OBJECT_NAME = GCS_JAR_DIRECT_RUNNER_PARTS.path[1:]
-GCS_JAR_DATAFLOW_RUNNER_PARTS = urlparse(GCS_JAR_DATAFLOW_RUNNER)
-GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.netloc
-GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME = GCS_JAR_DATAFLOW_RUNNER_PARTS.path[1:]
-GCS_JAR_SPARK_RUNNER_PARTS = urlparse(GCS_JAR_SPARK_RUNNER)
-GCS_JAR_SPARK_RUNNER_BUCKET_NAME = GCS_JAR_SPARK_RUNNER_PARTS.netloc
-GCS_JAR_SPARK_RUNNER_OBJECT_NAME = GCS_JAR_SPARK_RUNNER_PARTS.path[1:]
-GCS_JAR_FLINK_RUNNER_PARTS = urlparse(GCS_JAR_FLINK_RUNNER)
-GCS_JAR_FLINK_RUNNER_BUCKET_NAME = GCS_JAR_FLINK_RUNNER_PARTS.netloc
-GCS_JAR_FLINK_RUNNER_OBJECT_NAME = GCS_JAR_FLINK_RUNNER_PARTS.path[1:]
-
-
-DEFAULT_ARGS = {
- 'default_pipeline_options': {'output': '/tmp/example_beam'},
- 'trigger_rule': TriggerRule.ALL_DONE,
-}
-START_DATE = datetime(2021, 1, 1)
-
-
-with models.DAG(
- "example_beam_native_java_direct_runner",
- schedule_interval=None, # Override to match your needs
- start_date=START_DATE,
- catchup=False,
- tags=['example'],
-) as dag_native_java_direct_runner:
-
- # [START howto_operator_start_java_direct_runner_pipeline]
- jar_to_local_direct_runner = GCSToLocalFilesystemOperator(
- task_id="jar_to_local_direct_runner",
- bucket=GCS_JAR_DIRECT_RUNNER_BUCKET_NAME,
- object_name=GCS_JAR_DIRECT_RUNNER_OBJECT_NAME,
- filename="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
- )
-
- start_java_pipeline_direct_runner = BeamRunJavaPipelineOperator(
- task_id="start_java_pipeline_direct_runner",
- jar="/tmp/beam_wordcount_direct_runner_{{ ds_nodash }}.jar",
- pipeline_options={
- 'output': '/tmp/start_java_pipeline_direct_runner',
- 'inputFile': GCS_INPUT,
- },
- job_class='org.apache.beam.examples.WordCount',
- )
-
- jar_to_local_direct_runner >> start_java_pipeline_direct_runner
- # [END howto_operator_start_java_direct_runner_pipeline]
-
-with models.DAG(
- "example_beam_native_java_dataflow_runner",
- schedule_interval=None, # Override to match your needs
- start_date=START_DATE,
- catchup=False,
- tags=['example'],
-) as dag_native_java_dataflow_runner:
- # [START howto_operator_start_java_dataflow_runner_pipeline]
- jar_to_local_dataflow_runner = GCSToLocalFilesystemOperator(
- task_id="jar_to_local_dataflow_runner",
- bucket=GCS_JAR_DATAFLOW_RUNNER_BUCKET_NAME,
- object_name=GCS_JAR_DATAFLOW_RUNNER_OBJECT_NAME,
- filename="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
- )
-
- start_java_pipeline_dataflow = BeamRunJavaPipelineOperator(
- task_id="start_java_pipeline_dataflow",
- runner="DataflowRunner",
- jar="/tmp/beam_wordcount_dataflow_runner_{{ ds_nodash }}.jar",
- pipeline_options={
- 'tempLocation': GCS_TMP,
- 'stagingLocation': GCS_STAGING,
- 'output': GCS_OUTPUT,
- },
- job_class='org.apache.beam.examples.WordCount',
- dataflow_config={"job_name": "{{task.task_id}}", "location": "us-central1"},
- )
-
- jar_to_local_dataflow_runner >> start_java_pipeline_dataflow
- # [END howto_operator_start_java_dataflow_runner_pipeline]
-
-with models.DAG(
- "example_beam_native_java_spark_runner",
- schedule_interval=None, # Override to match your needs
- start_date=START_DATE,
- catchup=False,
- tags=['example'],
-) as dag_native_java_spark_runner:
-
- jar_to_local_spark_runner = GCSToLocalFilesystemOperator(
- task_id="jar_to_local_spark_runner",
- bucket=GCS_JAR_SPARK_RUNNER_BUCKET_NAME,
- object_name=GCS_JAR_SPARK_RUNNER_OBJECT_NAME,
- filename="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar",
- )
-
- start_java_pipeline_spark_runner = BeamRunJavaPipelineOperator(
- task_id="start_java_pipeline_spark_runner",
- runner="SparkRunner",
- jar="/tmp/beam_wordcount_spark_runner_{{ ds_nodash }}.jar",
- pipeline_options={
- 'output': '/tmp/start_java_pipeline_spark_runner',
- 'inputFile': GCS_INPUT,
- },
- job_class='org.apache.beam.examples.WordCount',
- )
-
- jar_to_local_spark_runner >> start_java_pipeline_spark_runner
-
-with models.DAG(
- "example_beam_native_java_flink_runner",
- schedule_interval=None, # Override to match your needs
- start_date=START_DATE,
- catchup=False,
- tags=['example'],
-) as dag_native_java_flink_runner:
-
- jar_to_local_flink_runner = GCSToLocalFilesystemOperator(
- task_id="jar_to_local_flink_runner",
- bucket=GCS_JAR_FLINK_RUNNER_BUCKET_NAME,
- object_name=GCS_JAR_FLINK_RUNNER_OBJECT_NAME,
- filename="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar",
- )
-
- start_java_pipeline_flink_runner = BeamRunJavaPipelineOperator(
- task_id="start_java_pipeline_flink_runner",
- runner="FlinkRunner",
- jar="/tmp/beam_wordcount_flink_runner_{{ ds_nodash }}.jar",
- pipeline_options={
- 'output': '/tmp/start_java_pipeline_flink_runner',
- 'inputFile': GCS_INPUT,
- },
- job_class='org.apache.beam.examples.WordCount',
- )
-
- jar_to_local_flink_runner >> start_java_pipeline_flink_runner
-
-
-with models.DAG(
- "example_beam_native_python",
- start_date=START_DATE,
- schedule_interval=None, # Override to match your needs
- catchup=False,
- default_args=DEFAULT_ARGS,
- tags=['example'],
-) as dag_native_python:
-
- # [START howto_operator_start_python_direct_runner_pipeline_local_file]
- start_python_pipeline_local_direct_runner = BeamRunPythonPipelineOperator(
- task_id="start_python_pipeline_local_direct_runner",
- py_file='apache_beam.examples.wordcount',
- py_options=['-m'],
- py_requirements=['apache-beam[gcp]==2.26.0'],
- py_interpreter='python3',
- py_system_site_packages=False,
- )
- # [END howto_operator_start_python_direct_runner_pipeline_local_file]
-
- # [START howto_operator_start_python_direct_runner_pipeline_gcs_file]
- start_python_pipeline_direct_runner = BeamRunPythonPipelineOperator(
- task_id="start_python_pipeline_direct_runner",
- py_file=GCS_PYTHON,
- py_options=[],
- pipeline_options={"output": GCS_OUTPUT},
- py_requirements=['apache-beam[gcp]==2.26.0'],
- py_interpreter='python3',
- py_system_site_packages=False,
- )
- # [END howto_operator_start_python_direct_runner_pipeline_gcs_file]
-
- # [START howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
- start_python_pipeline_dataflow_runner = BeamRunPythonPipelineOperator(
- task_id="start_python_pipeline_dataflow_runner",
- runner="DataflowRunner",
- py_file=GCS_PYTHON,
- pipeline_options={
- 'tempLocation': GCS_TMP,
- 'stagingLocation': GCS_STAGING,
- 'output': GCS_OUTPUT,
- },
- py_options=[],
- py_requirements=['apache-beam[gcp]==2.26.0'],
- py_interpreter='python3',
- py_system_site_packages=False,
- dataflow_config=DataflowConfiguration(
- job_name='{{task.task_id}}', project_id=GCP_PROJECT_ID, location="us-central1"
- ),
- )
- # [END howto_operator_start_python_dataflow_runner_pipeline_gcs_file]
-
- start_python_pipeline_local_spark_runner = BeamRunPythonPipelineOperator(
- task_id="start_python_pipeline_local_spark_runner",
- py_file='apache_beam.examples.wordcount',
- runner="SparkRunner",
- py_options=['-m'],
- py_requirements=['apache-beam[gcp]==2.26.0'],
- py_interpreter='python3',
- py_system_site_packages=False,
- )
-
- start_python_pipeline_local_flink_runner = BeamRunPythonPipelineOperator(
- task_id="start_python_pipeline_local_flink_runner",
- py_file='apache_beam.examples.wordcount',
- runner="FlinkRunner",
- py_options=['-m'],
- pipeline_options={
- 'output': '/tmp/start_python_pipeline_local_flink_runner',
- },
- py_requirements=['apache-beam[gcp]==2.26.0'],
- py_interpreter='python3',
- py_system_site_packages=False,
- )
-
- (
- [
- start_python_pipeline_local_direct_runner,
- start_python_pipeline_direct_runner,
- ]
- >> start_python_pipeline_local_flink_runner
- >> start_python_pipeline_local_spark_runner
- )
-
-
-with models.DAG(
- "example_beam_native_python_dataflow_async",
- default_args=DEFAULT_ARGS,
- start_date=START_DATE,
- schedule_interval=None, # Override to match your needs
- catchup=False,
- tags=['example'],
-) as dag_native_python_dataflow_async:
- # [START howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
- start_python_job_dataflow_runner_async = BeamRunPythonPipelineOperator(
- task_id="start_python_job_dataflow_runner_async",
- runner="DataflowRunner",
- py_file=GCS_PYTHON_DATAFLOW_ASYNC,
- pipeline_options={
- 'tempLocation': GCS_TMP,
- 'stagingLocation': GCS_STAGING,
- 'output': GCS_OUTPUT,
- },
- py_options=[],
- py_requirements=['apache-beam[gcp]==2.26.0'],
- py_interpreter='python3',
- py_system_site_packages=False,
- dataflow_config=DataflowConfiguration(
- job_name='{{task.task_id}}',
- project_id=GCP_PROJECT_ID,
- location="us-central1",
- wait_until_finished=False,
- ),
- )
-
- wait_for_python_job_dataflow_runner_async_done = DataflowJobStatusSensor(
- task_id="wait-for-python-job-async-done",
- job_id="{{task_instance.xcom_pull('start_python_job_dataflow_runner_async')['dataflow_job_id']}}",
- expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
- project_id=GCP_PROJECT_ID,
- location='us-central1',
- )
-
- start_python_job_dataflow_runner_async >> wait_for_python_job_dataflow_runner_async_done
- # [END howto_operator_start_python_dataflow_runner_pipeline_async_gcs_file]
-
-
-with models.DAG(
- "example_beam_native_go",
- start_date=START_DATE,
- schedule_interval="@once",
- catchup=False,
- default_args=DEFAULT_ARGS,
- tags=['example'],
-) as dag_native_go:
-
- # [START howto_operator_start_go_direct_runner_pipeline_local_file]
- start_go_pipeline_local_direct_runner = BeamRunGoPipelineOperator(
- task_id="start_go_pipeline_local_direct_runner",
- go_file='files/apache_beam/examples/wordcount.go',
- )
- # [END howto_operator_start_go_direct_runner_pipeline_local_file]
-
- # [START howto_operator_start_go_direct_runner_pipeline_gcs_file]
- start_go_pipeline_direct_runner = BeamRunGoPipelineOperator(
- task_id="start_go_pipeline_direct_runner",
- go_file=GCS_GO,
- pipeline_options={"output": GCS_OUTPUT},
- )
- # [END howto_operator_start_go_direct_runner_pipeline_gcs_file]
-
- # [START howto_operator_start_go_dataflow_runner_pipeline_gcs_file]
- start_go_pipeline_dataflow_runner = BeamRunGoPipelineOperator(
- task_id="start_go_pipeline_dataflow_runner",
- runner="DataflowRunner",
- go_file=GCS_GO,
- pipeline_options={
- 'tempLocation': GCS_TMP,
- 'stagingLocation': GCS_STAGING,
- 'output': GCS_OUTPUT,
- 'WorkerHarnessContainerImage': "apache/beam_go_sdk:latest",
- },
- dataflow_config=DataflowConfiguration(
- job_name='{{task.task_id}}', project_id=GCP_PROJECT_ID, location="us-central1"
- ),
- )
- # [END howto_operator_start_go_dataflow_runner_pipeline_gcs_file]
-
- start_go_pipeline_local_spark_runner = BeamRunGoPipelineOperator(
- task_id="start_go_pipeline_local_spark_runner",
- go_file='/files/apache_beam/examples/wordcount.go',
- runner="SparkRunner",
- pipeline_options={
- 'endpoint': '/your/spark/endpoint',
- },
- )
-
- start_go_pipeline_local_flink_runner = BeamRunGoPipelineOperator(
- task_id="start_go_pipeline_local_flink_runner",
- go_file='/files/apache_beam/examples/wordcount.go',
- runner="FlinkRunner",
- pipeline_options={
- 'output': '/tmp/start_go_pipeline_local_flink_runner',
- },
- )
-
- (
- [
- start_go_pipeline_local_direct_runner,
- start_go_pipeline_direct_runner,
- ]
- >> start_go_pipeline_local_flink_runner
- >> start_go_pipeline_local_spark_runner
- )
-
-
-with models.DAG(
- "example_beam_native_go_dataflow_async",
- default_args=DEFAULT_ARGS,
- start_date=START_DATE,
- schedule_interval="@once",
- catchup=False,
- tags=['example'],
-) as dag_native_go_dataflow_async:
- # [START howto_operator_start_go_dataflow_runner_pipeline_async_gcs_file]
- start_go_job_dataflow_runner_async = BeamRunGoPipelineOperator(
- task_id="start_go_job_dataflow_runner_async",
- runner="DataflowRunner",
- go_file=GCS_GO_DATAFLOW_ASYNC,
- pipeline_options={
- 'tempLocation': GCS_TMP,
- 'stagingLocation': GCS_STAGING,
- 'output': GCS_OUTPUT,
- 'WorkerHarnessContainerImage': "apache/beam_go_sdk:latest",
- },
- dataflow_config=DataflowConfiguration(
- job_name='{{task.task_id}}',
- project_id=GCP_PROJECT_ID,
- location="us-central1",
- wait_until_finished=False,
- ),
- )
-
- wait_for_go_job_dataflow_runner_async_done = DataflowJobStatusSensor(
- task_id="wait-for-go-job-async-done",
- job_id="{{task_instance.xcom_pull('start_go_job_dataflow_runner_async')['dataflow_job_id']}}",
- expected_statuses={DataflowJobStatus.JOB_STATE_DONE},
- project_id=GCP_PROJECT_ID,
- location='us-central1',
- )
-
- start_go_job_dataflow_runner_async >> wait_for_go_job_dataflow_runner_async_done
- # [END howto_operator_start_go_dataflow_runner_pipeline_async_gcs_file]
diff --git a/airflow/providers/apache/beam/hooks/beam.py b/airflow/providers/apache/beam/hooks/beam.py
index 0644e02b625f0..28a5abc0c6e3e 100644
--- a/airflow/providers/apache/beam/hooks/beam.py
+++ b/airflow/providers/apache/beam/hooks/beam.py
@@ -16,6 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains a Apache Beam Hook."""
+from __future__ import annotations
+
+import contextlib
import json
import os
import select
@@ -24,7 +27,9 @@
import subprocess
import textwrap
from tempfile import TemporaryDirectory
-from typing import Callable, List, Optional
+from typing import Callable
+
+from packaging.version import Version
from airflow.exceptions import AirflowConfigException, AirflowException
from airflow.hooks.base import BaseHook
@@ -50,7 +55,7 @@ class BeamRunnerType:
Twister2Runner = "Twister2Runner"
-def beam_options_to_args(options: dict) -> List[str]:
+def beam_options_to_args(options: dict) -> list[str]:
"""
Returns a formatted pipeline options from a dictionary of arguments
@@ -60,12 +65,11 @@ def beam_options_to_args(options: dict) -> List[str]:
:param options: Dictionary with options
:return: List of arguments
- :rtype: List[str]
"""
if not options:
return []
- args: List[str] = []
+ args: list[str] = []
for attr, value in options.items():
if value is None or (isinstance(value, bool) and value):
args.append(f"--{attr}")
@@ -88,14 +92,14 @@ class BeamCommandRunner(LoggingMixin):
def __init__(
self,
- cmd: List[str],
- process_line_callback: Optional[Callable[[str], None]] = None,
- working_directory: Optional[str] = None,
+ cmd: list[str],
+ process_line_callback: Callable[[str], None] | None = None,
+ working_directory: str | None = None,
) -> None:
super().__init__()
self.log.info("Running command: %s", " ".join(shlex.quote(c) for c in cmd))
self.process_line_callback = process_line_callback
- self.job_id: Optional[str] = None
+ self.job_id: str | None = None
self._proc = subprocess.Popen(
cmd,
@@ -173,9 +177,9 @@ def __init__(
def _start_pipeline(
self,
variables: dict,
- command_prefix: List[str],
- process_line_callback: Optional[Callable[[str], None]] = None,
- working_directory: Optional[str] = None,
+ command_prefix: list[str],
+ process_line_callback: Callable[[str], None] | None = None,
+ working_directory: str | None = None,
) -> None:
cmd = command_prefix + [
f"--runner={self.runner}",
@@ -193,11 +197,11 @@ def start_python_pipeline(
self,
variables: dict,
py_file: str,
- py_options: List[str],
+ py_options: list[str],
py_interpreter: str = "python3",
- py_requirements: Optional[List[str]] = None,
+ py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
- process_line_callback: Optional[Callable[[str], None]] = None,
+ process_line_callback: Callable[[str], None] | None = None,
):
"""
Starts Apache Beam python pipeline.
@@ -225,37 +229,47 @@ def start_python_pipeline(
if "labels" in variables:
variables["labels"] = [f"{key}={value}" for key, value in variables["labels"].items()]
- if py_requirements is not None:
- if not py_requirements and not py_system_site_packages:
- warning_invalid_environment = textwrap.dedent(
- """\
- Invalid method invocation. You have disabled inclusion of system packages and empty list
- required for installation, so it is not possible to create a valid virtual environment.
- In the virtual environment, apache-beam package must be installed for your job to be \
- executed. To fix this problem:
- * install apache-beam on the system, then set parameter py_system_site_packages to True,
- * add apache-beam to the list of required packages in parameter py_requirements.
- """
- )
- raise AirflowException(warning_invalid_environment)
-
- with TemporaryDirectory(prefix="apache-beam-venv") as tmp_dir:
+ with contextlib.ExitStack() as exit_stack:
+ if py_requirements is not None:
+ if not py_requirements and not py_system_site_packages:
+ warning_invalid_environment = textwrap.dedent(
+ """\
+ Invalid method invocation. You have disabled inclusion of system packages and empty
+ list required for installation, so it is not possible to create a valid virtual
+ environment. In the virtual environment, apache-beam package must be installed for
+ your job to be executed.
+
+ To fix this problem:
+ * install apache-beam on the system, then set parameter py_system_site_packages
+ to True,
+ * add apache-beam to the list of required packages in parameter py_requirements.
+ """
+ )
+ raise AirflowException(warning_invalid_environment)
+ tmp_dir = exit_stack.enter_context(TemporaryDirectory(prefix="apache-beam-venv"))
py_interpreter = prepare_virtualenv(
venv_directory=tmp_dir,
python_bin=py_interpreter,
system_site_packages=py_system_site_packages,
requirements=py_requirements,
)
- command_prefix = [py_interpreter] + py_options + [py_file]
- self._start_pipeline(
- variables=variables,
- command_prefix=command_prefix,
- process_line_callback=process_line_callback,
- )
- else:
command_prefix = [py_interpreter] + py_options + [py_file]
+ beam_version = (
+ subprocess.check_output(
+ [py_interpreter, "-c", "import apache_beam; print(apache_beam.__version__)"]
+ )
+ .decode()
+ .strip()
+ )
+ self.log.info("Beam version: %s", beam_version)
+ impersonate_service_account = variables.get("impersonate_service_account")
+ if impersonate_service_account:
+ if Version(beam_version) < Version("2.39.0"):
+ raise AirflowException(
+ "The impersonateServiceAccount option requires Apache Beam 2.39.0 or newer."
+ )
self._start_pipeline(
variables=variables,
command_prefix=command_prefix,
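A minimal sketch, separate from the patch itself, of the two patterns the hunk above combines: entering a temporary directory only when a virtualenv is requested (via ``contextlib.ExitStack``) and rejecting an option when the installed Beam version is too old. The helper names ``parse_release``, ``check_impersonation_supported`` and ``run_with_optional_venv`` are hypothetical. The virtualenv step is replaced by a placeholder path. The version comparison is a simplified standard-library stand-in for ``packaging.version.Version`` that only understands plain ``X.Y.Z`` releases.

```python
from __future__ import annotations

import contextlib
import os
import re
from tempfile import TemporaryDirectory

# Assumption: mirrors the "2.39.0" threshold quoted in the hunk above.
MIN_IMPERSONATION_VERSION = (2, 39, 0)


def parse_release(version: str) -> tuple[int, ...]:
    """Parse the leading X.Y.Z of a version string into a comparable tuple."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def check_impersonation_supported(beam_version: str, variables: dict) -> None:
    """Raise only when impersonation is requested AND the version is too old."""
    if not variables.get("impersonate_service_account"):
        return
    if parse_release(beam_version) < MIN_IMPERSONATION_VERSION:
        raise RuntimeError(
            "The impersonateServiceAccount option requires Apache Beam 2.39.0 or newer."
        )


def run_with_optional_venv(
    py_interpreter: str, py_requirements: list[str] | None
) -> tuple[str, str | None]:
    """Return the interpreter that would be used and the temp dir, if one was created."""
    tmp_dir = None
    with contextlib.ExitStack() as exit_stack:
        if py_requirements is not None:
            # The directory is registered on the stack, so it is removed when
            # the with-block exits, whichever branch was taken.
            tmp_dir = exit_stack.enter_context(TemporaryDirectory(prefix="apache-beam-venv"))
            # Placeholder for the real virtualenv preparation step.
            py_interpreter = os.path.join(tmp_dir, "bin", "python")
        # A shared "start the pipeline" call would go here, once, for both branches.
    return py_interpreter, tmp_dir
```

The point of the ``ExitStack`` form is that the code after the conditional is written once instead of being duplicated in an ``if``/``else`` pair, while the temporary directory is still cleaned up on exit.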
@@ -266,8 +280,8 @@ def start_java_pipeline(
self,
variables: dict,
jar: str,
- job_class: Optional[str] = None,
- process_line_callback: Optional[Callable[[str], None]] = None,
+ job_class: str | None = None,
+ process_line_callback: Callable[[str], None] | None = None,
) -> None:
"""
Starts Apache Beam Java pipeline.
@@ -292,7 +306,7 @@ def start_go_pipeline(
self,
variables: dict,
go_file: str,
- process_line_callback: Optional[Callable[[str], None]] = None,
+ process_line_callback: Callable[[str], None] | None = None,
should_init_module: bool = False,
) -> None:
"""
diff --git a/airflow/providers/apache/beam/operators/beam.py b/airflow/providers/apache/beam/operators/beam.py
index e7f7af5e236a5..efef187aaf28a 100644
--- a/airflow/providers/apache/beam/operators/beam.py
+++ b/airflow/providers/apache/beam/operators/beam.py
@@ -16,11 +16,13 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains Apache Beam operators."""
+from __future__ import annotations
+
import copy
import tempfile
from abc import ABC, ABCMeta
from contextlib import ExitStack
-from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Callable, Sequence
from airflow import AirflowException
from airflow.models import BaseOperator
@@ -47,17 +49,17 @@ class BeamDataflowMixin(metaclass=ABCMeta):
:class:`~airflow.providers.apache.beam.operators.beam.BeamRunGoPipelineOperator`.
"""
- dataflow_hook: Optional[DataflowHook]
+ dataflow_hook: DataflowHook | None
dataflow_config: DataflowConfiguration
gcp_conn_id: str
- delegate_to: Optional[str]
+ delegate_to: str | None
dataflow_support_impersonation: bool = True
def _set_dataflow(
self,
pipeline_options: dict,
- job_name_variable_key: Optional[str] = None,
- ) -> Tuple[str, dict, Callable[[str], None]]:
+ job_name_variable_key: str | None = None,
+ ) -> tuple[str, dict, Callable[[str], None]]:
self.dataflow_hook = self.__set_dataflow_hook()
self.dataflow_config.project_id = self.dataflow_config.project_id or self.dataflow_hook.project_id
dataflow_job_name = self.__get_dataflow_job_name()
@@ -85,7 +87,7 @@ def __get_dataflow_job_name(self) -> str:
)
def __get_dataflow_pipeline_options(
- self, pipeline_options: dict, job_name: str, job_name_key: Optional[str] = None
+ self, pipeline_options: dict, job_name: str, job_name_key: str | None = None
) -> dict:
pipeline_options = copy.deepcopy(pipeline_options)
if job_name_key is not None:
@@ -151,11 +153,11 @@ def __init__(
self,
*,
runner: str = "DirectRunner",
- default_pipeline_options: Optional[dict] = None,
- pipeline_options: Optional[dict] = None,
+ default_pipeline_options: dict | None = None,
+ pipeline_options: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+ delegate_to: str | None = None,
+ dataflow_config: DataflowConfiguration | dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -168,9 +170,9 @@ def __init__(
self.dataflow_config = DataflowConfiguration(**dataflow_config)
else:
self.dataflow_config = dataflow_config or DataflowConfiguration()
- self.beam_hook: Optional[BeamHook] = None
- self.dataflow_hook: Optional[DataflowHook] = None
- self.dataflow_job_id: Optional[str] = None
+ self.beam_hook: BeamHook | None = None
+ self.dataflow_hook: DataflowHook | None = None
+ self.dataflow_job_id: str | None = None
if self.dataflow_config and self.runner.lower() != BeamRunnerType.DataflowRunner.lower():
self.log.warning(
@@ -180,13 +182,13 @@ def __init__(
def _init_pipeline_options(
self,
format_pipeline_options: bool = False,
- job_name_variable_key: Optional[str] = None,
- ) -> Tuple[bool, Optional[str], dict, Optional[Callable[[str], None]]]:
+ job_name_variable_key: str | None = None,
+ ) -> tuple[bool, str | None, dict, Callable[[str], None] | None]:
self.beam_hook = BeamHook(runner=self.runner)
pipeline_options = self.default_pipeline_options.copy()
- process_line_callback: Optional[Callable[[str], None]] = None
+ process_line_callback: Callable[[str], None] | None = None
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()
- dataflow_job_name: Optional[str] = None
+ dataflow_job_name: str | None = None
if is_dataflow:
dataflow_job_name, pipeline_options, process_line_callback = self._set_dataflow(
pipeline_options=pipeline_options,
@@ -247,7 +249,7 @@ class BeamRunPythonPipelineOperator(BeamBasePipelineOperator):
"default_pipeline_options",
"dataflow_config",
)
- template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+ template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"}
operator_extra_links = (DataflowJobLink(),)
def __init__(
@@ -255,15 +257,15 @@ def __init__(
*,
py_file: str,
runner: str = "DirectRunner",
- default_pipeline_options: Optional[dict] = None,
- pipeline_options: Optional[dict] = None,
+ default_pipeline_options: dict | None = None,
+ pipeline_options: dict | None = None,
py_interpreter: str = "python3",
- py_options: Optional[List[str]] = None,
- py_requirements: Optional[List[str]] = None,
+ py_options: list[str] | None = None,
+ py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+ delegate_to: str | None = None,
+ dataflow_config: DataflowConfiguration | dict | None = None,
**kwargs,
) -> None:
super().__init__(
@@ -285,7 +287,7 @@ def __init__(
{"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
(
is_dataflow,
@@ -343,7 +345,7 @@ def execute(self, context: 'Context'):
def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
- self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+ self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id)
self.dataflow_hook.cancel_job(
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
@@ -386,7 +388,7 @@ class BeamRunJavaPipelineOperator(BeamBasePipelineOperator):
"default_pipeline_options",
"dataflow_config",
)
- template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+ template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"}
ui_color = "#0273d4"
operator_extra_links = (DataflowJobLink(),)
@@ -396,12 +398,12 @@ def __init__(
*,
jar: str,
runner: str = "DirectRunner",
- job_class: Optional[str] = None,
- default_pipeline_options: Optional[dict] = None,
- pipeline_options: Optional[dict] = None,
+ job_class: str | None = None,
+ default_pipeline_options: dict | None = None,
+ pipeline_options: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+ delegate_to: str | None = None,
+ dataflow_config: DataflowConfiguration | dict | None = None,
**kwargs,
) -> None:
super().__init__(
@@ -416,7 +418,7 @@ def __init__(
self.jar = jar
self.job_class = job_class
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
(
is_dataflow,
@@ -469,11 +471,7 @@ def execute(self, context: 'Context'):
process_line_callback=process_line_callback,
)
if dataflow_job_name and self.dataflow_config.location:
- multiple_jobs = (
- self.dataflow_config.multiple_jobs
- if self.dataflow_config.multiple_jobs
- else False
- )
+ multiple_jobs = self.dataflow_config.multiple_jobs or False
DataflowJobLink.persist(
self,
context,
@@ -499,7 +497,7 @@ def execute(self, context: 'Context'):
def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
- self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+ self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id)
self.dataflow_hook.cancel_job(
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
@@ -533,7 +531,7 @@ class BeamRunGoPipelineOperator(BeamBasePipelineOperator):
"default_pipeline_options",
"dataflow_config",
]
- template_fields_renderers = {'dataflow_config': 'json', 'pipeline_options': 'json'}
+ template_fields_renderers = {"dataflow_config": "json", "pipeline_options": "json"}
operator_extra_links = (DataflowJobLink(),)
def __init__(
@@ -541,11 +539,11 @@ def __init__(
*,
go_file: str,
runner: str = "DirectRunner",
- default_pipeline_options: Optional[dict] = None,
- pipeline_options: Optional[dict] = None,
+ default_pipeline_options: dict | None = None,
+ pipeline_options: dict | None = None,
gcp_conn_id: str = "google_cloud_default",
- delegate_to: Optional[str] = None,
- dataflow_config: Optional[Union[DataflowConfiguration, dict]] = None,
+ delegate_to: str | None = None,
+ dataflow_config: DataflowConfiguration | dict | None = None,
**kwargs,
) -> None:
super().__init__(
@@ -571,7 +569,7 @@ def __init__(
{"airflow-version": "v" + version.replace(".", "-").replace("+", "-")}
)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
"""Execute the Apache Beam Pipeline."""
(
is_dataflow,
@@ -629,7 +627,7 @@ def execute(self, context: 'Context'):
def on_kill(self) -> None:
if self.dataflow_hook and self.dataflow_job_id:
- self.log.info('Dataflow job with id: `%s` was requested to be cancelled.', self.dataflow_job_id)
+ self.log.info("Dataflow job with id: `%s` was requested to be cancelled.", self.dataflow_job_id)
self.dataflow_hook.cancel_job(
job_id=self.dataflow_job_id,
project_id=self.dataflow_config.project_id,
diff --git a/airflow/providers/apache/beam/provider.yaml b/airflow/providers/apache/beam/provider.yaml
index 4d6bbeab7a23c..5f83ce135203a 100644
--- a/airflow/providers/apache/beam/provider.yaml
+++ b/airflow/providers/apache/beam/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Apache Beam `__.
versions:
+ - 4.1.0
+ - 4.0.0
- 3.4.0
- 3.3.0
- 3.2.1
@@ -33,8 +35,9 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-beam>=2.33.0
integrations:
- integration-name: Apache Beam
@@ -54,4 +57,6 @@ hooks:
- airflow.providers.apache.beam.hooks.beam
additional-extras:
- google: apache-beam[gcp]
+ - name: google
+ dependencies:
+ - apache-beam[gcp]
diff --git a/airflow/providers/apache/cassandra/.latest-doc-only-change.txt b/airflow/providers/apache/cassandra/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/apache/cassandra/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/cassandra/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/cassandra/CHANGELOG.rst b/airflow/providers/apache/cassandra/CHANGELOG.rst
index 029740ddef0b2..51d14af93e4e9 100644
--- a/airflow/providers/apache/cassandra/CHANGELOG.rst
+++ b/airflow/providers/apache/cassandra/CHANGELOG.rst
@@ -16,9 +16,56 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.1.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Misc
+~~~~
+ * ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``AIP-47 - Migrate cassandra DAGs to new design #22439 (#24209)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.1.3
.....
diff --git a/airflow/providers/apache/cassandra/hooks/cassandra.py b/airflow/providers/apache/cassandra/hooks/cassandra.py
index 3d250741d2fc8..e058556c21b44 100644
--- a/airflow/providers/apache/cassandra/hooks/cassandra.py
+++ b/airflow/providers/apache/cassandra/hooks/cassandra.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains hook to integrate with Apache Cassandra."""
+from __future__ import annotations
-from typing import Any, Dict, Union
+from typing import Any, Union
from cassandra.auth import PlainTextAuthProvider
from cassandra.cluster import Cluster, Session
@@ -83,10 +83,10 @@ class CassandraHook(BaseHook, LoggingMixin):
For details of the Cluster config, see cassandra.cluster.
"""
- conn_name_attr = 'cassandra_conn_id'
- default_conn_name = 'cassandra_default'
- conn_type = 'cassandra'
- hook_name = 'Cassandra'
+ conn_name_attr = "cassandra_conn_id"
+ default_conn_name = "cassandra_default"
+ conn_type = "cassandra"
+ hook_name = "Cassandra"
def __init__(self, cassandra_conn_id: str = default_conn_name):
super().__init__()
@@ -94,31 +94,31 @@ def __init__(self, cassandra_conn_id: str = default_conn_name):
conn_config = {}
if conn.host:
- conn_config['contact_points'] = conn.host.split(',')
+ conn_config["contact_points"] = conn.host.split(",")
if conn.port:
- conn_config['port'] = int(conn.port)
+ conn_config["port"] = int(conn.port)
if conn.login:
- conn_config['auth_provider'] = PlainTextAuthProvider(username=conn.login, password=conn.password)
+ conn_config["auth_provider"] = PlainTextAuthProvider(username=conn.login, password=conn.password)
- policy_name = conn.extra_dejson.get('load_balancing_policy', None)
- policy_args = conn.extra_dejson.get('load_balancing_policy_args', {})
+ policy_name = conn.extra_dejson.get("load_balancing_policy", None)
+ policy_args = conn.extra_dejson.get("load_balancing_policy_args", {})
lb_policy = self.get_lb_policy(policy_name, policy_args)
if lb_policy:
- conn_config['load_balancing_policy'] = lb_policy
+ conn_config["load_balancing_policy"] = lb_policy
- cql_version = conn.extra_dejson.get('cql_version', None)
+ cql_version = conn.extra_dejson.get("cql_version", None)
if cql_version:
- conn_config['cql_version'] = cql_version
+ conn_config["cql_version"] = cql_version
- ssl_options = conn.extra_dejson.get('ssl_options', None)
+ ssl_options = conn.extra_dejson.get("ssl_options", None)
if ssl_options:
- conn_config['ssl_options'] = ssl_options
+ conn_config["ssl_options"] = ssl_options
- protocol_version = conn.extra_dejson.get('protocol_version', None)
+ protocol_version = conn.extra_dejson.get("protocol_version", None)
if protocol_version:
- conn_config['protocol_version'] = protocol_version
+ conn_config["protocol_version"] = protocol_version
self.cluster = Cluster(**conn_config)
self.keyspace = conn.schema
@@ -141,37 +141,36 @@ def shutdown_cluster(self) -> None:
self.cluster.shutdown()
@staticmethod
- def get_lb_policy(policy_name: str, policy_args: Dict[str, Any]) -> Policy:
+ def get_lb_policy(policy_name: str, policy_args: dict[str, Any]) -> Policy:
"""
Creates load balancing policy.
:param policy_name: Name of the policy to use.
:param policy_args: Parameters for the policy.
"""
- if policy_name == 'DCAwareRoundRobinPolicy':
- local_dc = policy_args.get('local_dc', '')
- used_hosts_per_remote_dc = int(policy_args.get('used_hosts_per_remote_dc', 0))
+ if policy_name == "DCAwareRoundRobinPolicy":
+ local_dc = policy_args.get("local_dc", "")
+ used_hosts_per_remote_dc = int(policy_args.get("used_hosts_per_remote_dc", 0))
return DCAwareRoundRobinPolicy(local_dc, used_hosts_per_remote_dc)
- if policy_name == 'WhiteListRoundRobinPolicy':
- hosts = policy_args.get('hosts')
+ if policy_name == "WhiteListRoundRobinPolicy":
+ hosts = policy_args.get("hosts")
if not hosts:
- raise Exception('Hosts must be specified for WhiteListRoundRobinPolicy')
+ raise Exception("Hosts must be specified for WhiteListRoundRobinPolicy")
return WhiteListRoundRobinPolicy(hosts)
- if policy_name == 'TokenAwarePolicy':
+ if policy_name == "TokenAwarePolicy":
allowed_child_policies = (
- 'RoundRobinPolicy',
- 'DCAwareRoundRobinPolicy',
- 'WhiteListRoundRobinPolicy',
+ "RoundRobinPolicy",
+ "DCAwareRoundRobinPolicy",
+ "WhiteListRoundRobinPolicy",
)
- child_policy_name = policy_args.get('child_load_balancing_policy', 'RoundRobinPolicy')
- child_policy_args = policy_args.get('child_load_balancing_policy_args', {})
+ child_policy_name = policy_args.get("child_load_balancing_policy", "RoundRobinPolicy")
+ child_policy_args = policy_args.get("child_load_balancing_policy_args", {})
if child_policy_name not in allowed_child_policies:
return TokenAwarePolicy(RoundRobinPolicy())
- else:
- child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
- return TokenAwarePolicy(child_policy)
+ child_policy = CassandraHook.get_lb_policy(child_policy_name, child_policy_args)
+ return TokenAwarePolicy(child_policy)
# Fallback to default RoundRobinPolicy
return RoundRobinPolicy()
@@ -184,12 +183,12 @@ def table_exists(self, table: str) -> bool:
Use dot notation to target a specific keyspace.
"""
keyspace = self.keyspace
- if '.' in table:
- keyspace, table = table.split('.', 1)
+ if "." in table:
+ keyspace, table = table.split(".", 1)
cluster_metadata = self.get_conn().cluster.metadata
return keyspace in cluster_metadata.keyspaces and table in cluster_metadata.keyspaces[keyspace].tables
- def record_exists(self, table: str, keys: Dict[str, str]) -> bool:
+ def record_exists(self, table: str, keys: dict[str, str]) -> bool:
"""
Checks if a record exists in Cassandra
@@ -198,9 +197,9 @@ def record_exists(self, table: str, keys: Dict[str, str]) -> bool:
:param keys: The keys and their values to check the existence.
"""
keyspace = self.keyspace
- if '.' in table:
- keyspace, table = table.split('.', 1)
- ks_str = " AND ".join(f"{key}=%({key})s" for key in keys.keys())
+ if "." in table:
+ keyspace, table = table.split(".", 1)
+ ks_str = " AND ".join(f"{key}=%({key})s" for key in keys)
query = f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}"
try:
result = self.get_conn().execute(query, keys)
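A minimal sketch, separate from the patch itself, of how the ``record_exists`` hunk above assembles its query: an optional ``keyspace.table`` prefix is split off, and each key becomes a named ``%(key)s`` placeholder. The function name ``build_record_exists_query`` and its ``default_keyspace`` parameter are hypothetical; the hook itself reads the keyspace from ``self.keyspace`` and passes the query to the driver.

```python
from __future__ import annotations


def build_record_exists_query(default_keyspace: str, table: str, keys: dict[str, str]) -> str:
    """Build the SELECT statement with one named placeholder per key."""
    keyspace = default_keyspace
    if "." in table:
        # Dot notation overrides the default keyspace; split only on the first dot.
        keyspace, table = table.split(".", 1)
    # Iterating a dict yields its keys, so ``keys.keys()`` is not needed.
    ks_str = " AND ".join(f"{key}=%({key})s" for key in keys)
    return f"SELECT * FROM {keyspace}.{table} WHERE {ks_str}"


query = build_record_exists_query("ks", "other.t", {"p1": "a", "p2": "b"})
```

Only the key values are bound through placeholders; the keyspace, table and key names are interpolated into the string directly, so they should come from trusted input.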
diff --git a/airflow/providers/apache/cassandra/provider.yaml b/airflow/providers/apache/cassandra/provider.yaml
index 311d12c3411e5..961bbd36798f7 100644
--- a/airflow/providers/apache/cassandra/provider.yaml
+++ b/airflow/providers/apache/cassandra/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Apache Cassandra `__.
versions:
+ - 3.1.0
+ - 3.0.0
- 2.1.3
- 2.1.2
- 2.1.1
@@ -31,8 +33,9 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - cassandra-driver>=3.13.0
integrations:
- integration-name: Apache Cassandra
@@ -53,9 +56,6 @@ hooks:
python-modules:
- airflow.providers.apache.cassandra.hooks.cassandra
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook
-
connection-types:
- hook-class-name: airflow.providers.apache.cassandra.hooks.cassandra.CassandraHook
connection-type: cassandra
diff --git a/airflow/providers/apache/cassandra/sensors/record.py b/airflow/providers/apache/cassandra/sensors/record.py
index f0a407297adfd..1221c6fc37995 100644
--- a/airflow/providers/apache/cassandra/sensors/record.py
+++ b/airflow/providers/apache/cassandra/sensors/record.py
@@ -19,8 +19,9 @@
This module contains sensor that check the existence
of a record in a Cassandra cluster.
"""
+from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.providers.apache.cassandra.hooks.cassandra import CassandraHook
from airflow.sensors.base import BaseSensorOperator
@@ -53,12 +54,12 @@ class CassandraRecordSensor(BaseSensorOperator):
when connecting to Cassandra cluster
"""
- template_fields: Sequence[str] = ('table', 'keys')
+ template_fields: Sequence[str] = ("table", "keys")
def __init__(
self,
*,
- keys: Dict[str, str],
+ keys: dict[str, str],
table: str,
cassandra_conn_id: str = CassandraHook.default_conn_name,
**kwargs: Any,
@@ -68,7 +69,7 @@ def __init__(
self.table = table
self.keys = keys
- def poke(self, context: "Context") -> bool:
- self.log.info('Sensor check existence of record: %s', self.keys)
+ def poke(self, context: Context) -> bool:
+ self.log.info("Sensor check existence of record: %s", self.keys)
hook = CassandraHook(self.cassandra_conn_id)
return hook.record_exists(self.table, self.keys)
diff --git a/airflow/providers/apache/cassandra/sensors/table.py b/airflow/providers/apache/cassandra/sensors/table.py
index 2f5e6681cb4b9..60e8594f9b9f2 100644
--- a/airflow/providers/apache/cassandra/sensors/table.py
+++ b/airflow/providers/apache/cassandra/sensors/table.py
@@ -15,11 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""
This module contains sensor that check the existence
of a table in a Cassandra cluster.
"""
+from __future__ import annotations
from typing import TYPE_CHECKING, Any, Sequence
@@ -52,7 +52,7 @@ class CassandraTableSensor(BaseSensorOperator):
when connecting to Cassandra cluster
"""
- template_fields: Sequence[str] = ('table',)
+ template_fields: Sequence[str] = ("table",)
def __init__(
self,
@@ -65,7 +65,7 @@ def __init__(
self.cassandra_conn_id = cassandra_conn_id
self.table = table
- def poke(self, context: "Context") -> bool:
- self.log.info('Sensor check existence of table: %s', self.table)
+ def poke(self, context: Context) -> bool:
+ self.log.info("Sensor check existence of table: %s", self.table)
hook = CassandraHook(self.cassandra_conn_id)
return hook.table_exists(self.table)
diff --git a/airflow/providers/apache/drill/CHANGELOG.rst b/airflow/providers/apache/drill/CHANGELOG.rst
index a2bc20f91d195..cfd26afb423aa 100644
--- a/airflow/providers/apache/drill/CHANGELOG.rst
+++ b/airflow/providers/apache/drill/CHANGELOG.rst
@@ -16,9 +16,87 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+2.3.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Features
+~~~~~~~~
+
+* ``Add SQLExecuteQueryOperator (#25717)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+2.2.1
+.....
+
+Misc
+~~~~
+
+* ``Add common-sql lower bound for common-sql (#25789)``
+
+.. Review and move the new changes to one of the sections above:
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+2.2.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Unify DbApiHook.run() method with the methods which override it (#23971)``
+
+
+2.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Move all SQL classes to common-sql provider (#24836)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``AIP-47 - Migrate drill DAGs to new design #22439 (#24206)``
+ * ``Prepare provider documentation 2022.05.11 (#23631)``
+ * ``Clean up in-line f-string concatenation (#23591)``
+ * ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
1.0.4
.....
diff --git a/airflow/providers/apache/drill/hooks/drill.py b/airflow/providers/apache/drill/hooks/drill.py
index a15658e9e38ce..1d847e27e68d8 100644
--- a/airflow/providers/apache/drill/hooks/drill.py
+++ b/airflow/providers/apache/drill/hooks/drill.py
@@ -15,13 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import Any, Iterable, Optional, Tuple
+from typing import Any, Iterable
from sqlalchemy import create_engine
from sqlalchemy.engine import Connection
-from airflow.hooks.dbapi import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
class DrillHook(DbApiHook):
@@ -38,24 +39,24 @@ class DrillHook(DbApiHook):
connection using the extras field e.g. ``{"storage_plugin": "dfs"}``.
"""
- conn_name_attr = 'drill_conn_id'
- default_conn_name = 'drill_default'
- conn_type = 'drill'
- hook_name = 'Drill'
+ conn_name_attr = "drill_conn_id"
+ default_conn_name = "drill_default"
+ conn_type = "drill"
+ hook_name = "Drill"
supports_autocommit = False
def get_conn(self) -> Connection:
"""Establish a connection to Drillbit."""
conn_md = self.get_connection(getattr(self, self.conn_name_attr))
- creds = f'{conn_md.login}:{conn_md.password}@' if conn_md.login else ''
+ creds = f"{conn_md.login}:{conn_md.password}@" if conn_md.login else ""
engine = create_engine(
f'{conn_md.extra_dejson.get("dialect_driver", "drill+sadrill")}://{creds}'
- f'{conn_md.host}:{conn_md.port}/'
+ f"{conn_md.host}:{conn_md.port}/"
f'{conn_md.extra_dejson.get("storage_plugin", "dfs")}'
)
self.log.info(
- 'Connected to the Drillbit at %s:%s as user %s', conn_md.host, conn_md.port, conn_md.login
+ "Connected to the Drillbit at %s:%s as user %s", conn_md.host, conn_md.port, conn_md.login
)
return engine.raw_connection()
@@ -68,11 +69,11 @@ def get_uri(self) -> str:
conn_md = self.get_connection(getattr(self, self.conn_name_attr))
host = conn_md.host
if conn_md.port is not None:
- host += f':{conn_md.port}'
- conn_type = 'drill' if not conn_md.conn_type else conn_md.conn_type
- dialect_driver = conn_md.extra_dejson.get('dialect_driver', 'drill+sadrill')
- storage_plugin = conn_md.extra_dejson.get('storage_plugin', 'dfs')
- return f'{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}'
+ host += f":{conn_md.port}"
+ conn_type = conn_md.conn_type or "drill"
+ dialect_driver = conn_md.extra_dejson.get("dialect_driver", "drill+sadrill")
+ storage_plugin = conn_md.extra_dejson.get("storage_plugin", "dfs")
+ return f"{conn_type}://{host}/{storage_plugin}?dialect_driver={dialect_driver}"
def set_autocommit(self, conn: Connection, autocommit: bool) -> NotImplementedError:
raise NotImplementedError("There are no transactions in Drill.")
@@ -80,8 +81,8 @@ def set_autocommit(self, conn: Connection, autocommit: bool) -> NotImplementedEr
def insert_rows(
self,
table: str,
- rows: Iterable[Tuple[str]],
- target_fields: Optional[Iterable[str]] = None,
+ rows: Iterable[tuple[str]],
+ target_fields: Iterable[str] | None = None,
commit_every: int = 1000,
replace: bool = False,
**kwargs: Any,
diff --git a/airflow/providers/apache/drill/operators/drill.py b/airflow/providers/apache/drill/operators/drill.py
index 791ed546c34fe..1be54e8405648 100644
--- a/airflow/providers/apache/drill/operators/drill.py
+++ b/airflow/providers/apache/drill/operators/drill.py
@@ -15,18 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union
+from __future__ import annotations
-import sqlparse
+import warnings
+from typing import Sequence
-from airflow.models import BaseOperator
-from airflow.providers.apache.drill.hooks.drill import DrillHook
+from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
-if TYPE_CHECKING:
- from airflow.utils.context import Context
-
-class DrillOperator(BaseOperator):
+class DrillOperator(SQLExecuteQueryOperator):
"""
Executes the provided SQL in the identified Drill environment.
@@ -42,28 +39,16 @@ class DrillOperator(BaseOperator):
:param parameters: (optional) the parameters to render the SQL query with.
"""
- template_fields: Sequence[str] = ('sql',)
- template_fields_renderers = {'sql': 'sql'}
- template_ext: Sequence[str] = ('.sql',)
- ui_color = '#ededed'
-
- def __init__(
- self,
- *,
- sql: str,
- drill_conn_id: str = 'drill_default',
- parameters: Optional[Union[Mapping, Iterable]] = None,
- **kwargs,
- ) -> None:
- super().__init__(**kwargs)
- self.sql = sql
- self.drill_conn_id = drill_conn_id
- self.parameters = parameters
- self.hook: Optional[DrillHook] = None
-
- def execute(self, context: 'Context'):
- self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
- self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
- sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True))
- no_term_sql = [s[:-1] for s in sql if s[-1] == ';']
- self.hook.run(no_term_sql, parameters=self.parameters)
+ template_fields: Sequence[str] = ("sql",)
+ template_fields_renderers = {"sql": "sql"}
+ template_ext: Sequence[str] = (".sql",)
+ ui_color = "#ededed"
+
+ def __init__(self, *, drill_conn_id: str = "drill_default", **kwargs) -> None:
+ super().__init__(conn_id=drill_conn_id, **kwargs)
+ warnings.warn(
+ """This class is deprecated.
+ Please use `airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator`.""",
+ DeprecationWarning,
+ stacklevel=2,
+ )
diff --git a/airflow/providers/apache/drill/provider.yaml b/airflow/providers/apache/drill/provider.yaml
index 6dfeade83e3c0..8cab31b8bce20 100644
--- a/airflow/providers/apache/drill/provider.yaml
+++ b/airflow/providers/apache/drill/provider.yaml
@@ -22,14 +22,21 @@ description: |
`Apache Drill `__.
versions:
+ - 2.3.0
+ - 2.2.1
+ - 2.2.0
+ - 2.1.0
+ - 2.0.0
- 1.0.4
- 1.0.3
- 1.0.2
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-common-sql>=1.3.1
+ - sqlalchemy-drill>=1.1.0
integrations:
- integration-name: Apache Drill
@@ -49,9 +56,6 @@ hooks:
python-modules:
- airflow.providers.apache.drill.hooks.drill
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.drill.hooks.drill.DrillHook
-
connection-types:
- hook-class-name: airflow.providers.apache.drill.hooks.drill.DrillHook
connection-type: drill
diff --git a/airflow/providers/apache/druid/CHANGELOG.rst b/airflow/providers/apache/druid/CHANGELOG.rst
index ae22cead26039..957a2a2bfc4ea 100644
--- a/airflow/providers/apache/druid/CHANGELOG.rst
+++ b/airflow/providers/apache/druid/CHANGELOG.rst
@@ -16,9 +16,90 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.3.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``BugFix - Druid Airflow Exception to about content (#27174)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+3.2.1
+.....
+
+Misc
+~~~~
+
+* ``Add common-sql lower bound for common-sql (#25789)``
+
+.. Review and move the new changes to one of the sections above:
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+3.2.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Move all "old" SQL operators to common.sql providers (#25350)``
+
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Move all SQL classes to common-sql provider (#24836)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Misc
+~~~~
+
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``AIP-47 - Migrate druid DAGs to new design #22439 (#24207)``
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.3.3
.....
diff --git a/airflow/providers/apache/druid/hooks/druid.py b/airflow/providers/apache/druid/hooks/druid.py
index 671c914be604f..abfd86d68e3b1 100644
--- a/airflow/providers/apache/druid/hooks/druid.py
+++ b/airflow/providers/apache/druid/hooks/druid.py
@@ -15,16 +15,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import time
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Iterable
import requests
from pydruid.db import connect
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.hooks.dbapi import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
class DruidHook(BaseHook):
@@ -44,16 +45,16 @@ class DruidHook(BaseHook):
def __init__(
self,
- druid_ingest_conn_id: str = 'druid_ingest_default',
+ druid_ingest_conn_id: str = "druid_ingest_default",
timeout: int = 1,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
) -> None:
super().__init__()
self.druid_ingest_conn_id = druid_ingest_conn_id
self.timeout = timeout
self.max_ingestion_time = max_ingestion_time
- self.header = {'content-type': 'application/json'}
+ self.header = {"content-type": "application/json"}
if self.timeout < 1:
raise ValueError("Druid timeout should be equal or greater than 1")
@@ -63,11 +64,11 @@ def get_conn_url(self) -> str:
conn = self.get_connection(self.druid_ingest_conn_id)
host = conn.host
port = conn.port
- conn_type = 'http' if not conn.conn_type else conn.conn_type
- endpoint = conn.extra_dejson.get('endpoint', '')
+ conn_type = conn.conn_type or "http"
+ endpoint = conn.extra_dejson.get("endpoint", "")
return f"{conn_type}://{host}:{port}/{endpoint}"
- def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]:
+ def get_auth(self) -> requests.auth.HTTPBasicAuth | None:
"""
Return username and password from connections tab as requests.auth.HTTPBasicAuth object.
@@ -81,18 +82,21 @@ def get_auth(self) -> Optional[requests.auth.HTTPBasicAuth]:
else:
return None
- def submit_indexing_job(self, json_index_spec: Union[Dict[str, Any], str]) -> None:
+ def submit_indexing_job(self, json_index_spec: dict[str, Any] | str) -> None:
"""Submit Druid ingestion job"""
url = self.get_conn_url()
self.log.info("Druid ingestion spec: %s", json_index_spec)
req_index = requests.post(url, data=json_index_spec, headers=self.header, auth=self.get_auth())
- if req_index.status_code != 200:
- raise AirflowException(f'Did not get 200 when submitting the Druid job to {url}')
+
+ code = req_index.status_code
+ if code != 200:
+ self.log.error("Error submitting the Druid job to %s (%s) %s", url, code, req_index.content)
+ raise AirflowException(f"Did not get 200 when submitting the Druid job to {url}")
req_json = req_index.json()
# Wait until the job is completed
- druid_task_id = req_json['task']
+ druid_task_id = req_json["task"]
self.log.info("Druid indexing task-id: %s", druid_task_id)
running = True
@@ -106,23 +110,23 @@ def submit_indexing_job(self, json_index_spec: Union[Dict[str, Any], str]) -> No
if self.max_ingestion_time and sec > self.max_ingestion_time:
# ensure that the job gets killed if the max ingestion time is exceeded
requests.post(f"{url}/{druid_task_id}/shutdown", auth=self.get_auth())
- raise AirflowException(f'Druid ingestion took more than {self.max_ingestion_time} seconds')
+ raise AirflowException(f"Druid ingestion took more than {self.max_ingestion_time} seconds")
time.sleep(self.timeout)
sec += self.timeout
- status = req_status.json()['status']['status']
- if status == 'RUNNING':
+ status = req_status.json()["status"]["status"]
+ if status == "RUNNING":
running = True
- elif status == 'SUCCESS':
+ elif status == "SUCCESS":
running = False # Great success!
- elif status == 'FAILED':
- raise AirflowException('Druid indexing job failed, check console for more info')
+ elif status == "FAILED":
+ raise AirflowException("Druid indexing job failed, check console for more info")
else:
- raise AirflowException(f'Could not get status of the job, got {status}')
+ raise AirflowException(f"Could not get status of the job, got {status}")
- self.log.info('Successful index')
+ self.log.info("Successful index")
class DruidDbApiHook(DbApiHook):
@@ -133,10 +137,10 @@ class DruidDbApiHook(DbApiHook):
For ingestion, please use druidHook.
"""
- conn_name_attr = 'druid_broker_conn_id'
- default_conn_name = 'druid_broker_default'
- conn_type = 'druid'
- hook_name = 'Druid'
+ conn_name_attr = "druid_broker_conn_id"
+ default_conn_name = "druid_broker_default"
+ conn_type = "druid"
+ hook_name = "Druid"
supports_autocommit = False
def get_conn(self) -> connect:
@@ -145,12 +149,12 @@ def get_conn(self) -> connect:
druid_broker_conn = connect(
host=conn.host,
port=conn.port,
- path=conn.extra_dejson.get('endpoint', '/druid/v2/sql'),
- scheme=conn.extra_dejson.get('schema', 'http'),
+ path=conn.extra_dejson.get("endpoint", "/druid/v2/sql"),
+ scheme=conn.extra_dejson.get("schema", "http"),
user=conn.login,
password=conn.password,
)
- self.log.info('Get the connection to druid broker on %s using user %s', conn.host, conn.login)
+ self.log.info("Get the connection to druid broker on %s using user %s", conn.host, conn.login)
return druid_broker_conn
def get_uri(self) -> str:
@@ -162,10 +166,10 @@ def get_uri(self) -> str:
conn = self.get_connection(getattr(self, self.conn_name_attr))
host = conn.host
if conn.port is not None:
- host += f':{conn.port}'
- conn_type = 'druid' if not conn.conn_type else conn.conn_type
- endpoint = conn.extra_dejson.get('endpoint', 'druid/v2/sql')
- return f'{conn_type}://{host}/{endpoint}'
+ host += f":{conn.port}"
+ conn_type = conn.conn_type or "druid"
+ endpoint = conn.extra_dejson.get("endpoint", "druid/v2/sql")
+ return f"{conn_type}://{host}/{endpoint}"
def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplementedError:
raise NotImplementedError()
@@ -173,8 +177,8 @@ def set_autocommit(self, conn: connect, autocommit: bool) -> NotImplementedError
def insert_rows(
self,
table: str,
- rows: Iterable[Tuple[str]],
- target_fields: Optional[Iterable[str]] = None,
+ rows: Iterable[tuple[str]],
+ target_fields: Iterable[str] | None = None,
commit_every: int = 1000,
replace: bool = False,
**kwargs: Any,
diff --git a/airflow/providers/apache/druid/operators/druid.py b/airflow/providers/apache/druid/operators/druid.py
index 9fd8b3595af5c..2ef39886dab0a 100644
--- a/airflow/providers/apache/druid/operators/druid.py
+++ b/airflow/providers/apache/druid/operators/druid.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
from airflow.providers.apache.druid.hooks.druid import DruidHook
@@ -37,17 +38,17 @@ class DruidOperator(BaseOperator):
:param max_ingestion_time: The maximum ingestion time before assuming the job failed
"""
- template_fields: Sequence[str] = ('json_index_file',)
- template_ext: Sequence[str] = ('.json',)
- template_fields_renderers = {'json_index_file': 'json'}
+ template_fields: Sequence[str] = ("json_index_file",)
+ template_ext: Sequence[str] = (".json",)
+ template_fields_renderers = {"json_index_file": "json"}
def __init__(
self,
*,
json_index_file: str,
- druid_ingest_conn_id: str = 'druid_ingest_default',
+ druid_ingest_conn_id: str = "druid_ingest_default",
timeout: int = 1,
- max_ingestion_time: Optional[int] = None,
+ max_ingestion_time: int | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -56,7 +57,7 @@ def __init__(
self.timeout = timeout
self.max_ingestion_time = max_ingestion_time
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
hook = DruidHook(
druid_ingest_conn_id=self.conn_id,
timeout=self.timeout,
diff --git a/airflow/providers/apache/druid/operators/druid_check.py b/airflow/providers/apache/druid/operators/druid_check.py
index 33a4151350e55..84c00c4d61b0f 100644
--- a/airflow/providers/apache/druid/operators/druid_check.py
+++ b/airflow/providers/apache/druid/operators/druid_check.py
@@ -15,21 +15,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import warnings
-from airflow.operators.sql import SQLCheckOperator
+from airflow.providers.common.sql.operators.sql import SQLCheckOperator
class DruidCheckOperator(SQLCheckOperator):
"""
This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.
"""
- def __init__(self, druid_broker_conn_id: str = 'druid_broker_default', **kwargs):
+ def __init__(self, druid_broker_conn_id: str = "druid_broker_default", **kwargs):
warnings.warn(
"""This class is deprecated.
- Please use `airflow.operators.sql.SQLCheckOperator`.""",
+ Please use `airflow.providers.common.sql.operators.sql.SQLCheckOperator`.""",
DeprecationWarning,
stacklevel=2,
)
diff --git a/airflow/providers/apache/druid/provider.yaml b/airflow/providers/apache/druid/provider.yaml
index 214a1f15425ad..34e577e24b83d 100644
--- a/airflow/providers/apache/druid/provider.yaml
+++ b/airflow/providers/apache/druid/provider.yaml
@@ -22,6 +22,11 @@ description: |
`Apache Druid `__.
versions:
+ - 3.3.0
+ - 3.2.1
+ - 3.2.0
+ - 3.1.0
+ - 3.0.0
- 2.3.3
- 2.3.2
- 2.3.1
@@ -35,8 +40,10 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-common-sql>=1.3.1
+ - pydruid>=0.4.1
integrations:
- integration-name: Apache Druid
@@ -57,8 +64,6 @@ hooks:
python-modules:
- airflow.providers.apache.druid.hooks.druid
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.druid.hooks.druid.DruidDbApiHook
connection-types:
- hook-class-name: airflow.providers.apache.druid.hooks.druid.DruidDbApiHook
diff --git a/airflow/providers/apache/druid/transfers/hive_to_druid.py b/airflow/providers/apache/druid/transfers/hive_to_druid.py
index dc5f74109acf1..4c7523dd8d434 100644
--- a/airflow/providers/apache/druid/transfers/hive_to_druid.py
+++ b/airflow/providers/apache/druid/transfers/hive_to_druid.py
@@ -15,10 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains operator to move data from Hive to Druid."""
+from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
from airflow.providers.apache.druid.hooks.druid import DruidHook
@@ -64,9 +64,9 @@ class HiveToDruidOperator(BaseOperator):
:param job_properties: additional properties for job
"""
- template_fields: Sequence[str] = ('sql', 'intervals')
- template_ext: Sequence[str] = ('.sql',)
- template_fields_renderers = {'sql': 'hql'}
+ template_fields: Sequence[str] = ("sql", "intervals")
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "hql"}
def __init__(
self,
@@ -74,25 +74,25 @@ def __init__(
sql: str,
druid_datasource: str,
ts_dim: str,
- metric_spec: Optional[List[Any]] = None,
- hive_cli_conn_id: str = 'hive_cli_default',
- druid_ingest_conn_id: str = 'druid_ingest_default',
- metastore_conn_id: str = 'metastore_default',
- hadoop_dependency_coordinates: Optional[List[str]] = None,
- intervals: Optional[List[Any]] = None,
+ metric_spec: list[Any] | None = None,
+ hive_cli_conn_id: str = "hive_cli_default",
+ druid_ingest_conn_id: str = "druid_ingest_default",
+ metastore_conn_id: str = "metastore_default",
+ hadoop_dependency_coordinates: list[str] | None = None,
+ intervals: list[Any] | None = None,
num_shards: float = -1,
target_partition_size: int = -1,
query_granularity: str = "NONE",
segment_granularity: str = "DAY",
- hive_tblproperties: Optional[Dict[Any, Any]] = None,
- job_properties: Optional[Dict[Any, Any]] = None,
+ hive_tblproperties: dict[Any, Any] | None = None,
+ job_properties: dict[Any, Any] | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.sql = sql
self.druid_datasource = druid_datasource
self.ts_dim = ts_dim
- self.intervals = intervals or ['{{ ds }}/{{ tomorrow_ds }}']
+ self.intervals = intervals or ["{{ ds }}/{{ tomorrow_ds }}"]
self.num_shards = num_shards
self.target_partition_size = target_partition_size
self.query_granularity = query_granularity
@@ -105,12 +105,12 @@ def __init__(
self.hive_tblproperties = hive_tblproperties or {}
self.job_properties = job_properties
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
self.log.info("Extracting data from Hive")
- hive_table = 'druid.' + context['task_instance_key_str'].replace('.', '_')
- sql = self.sql.strip().strip(';')
- tblproperties = ''.join(f", '{k}' = '{v}'" for k, v in self.hive_tblproperties.items())
+ hive_table = "druid." + context["task_instance_key_str"].replace(".", "_")
+ sql = self.sql.strip().strip(";")
+ tblproperties = "".join(f", '{k}' = '{v}'" for k, v in self.hive_tblproperties.items())
hql = f"""\
SET mapred.output.compress=false;
SET hive.exec.compress.output=false;
@@ -152,7 +152,7 @@ def execute(self, context: "Context") -> None:
hql = f"DROP TABLE IF EXISTS {hive_table}"
hive.run_cli(hql)
- def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[str, Any]:
+ def construct_ingest_query(self, static_path: str, columns: list[str]) -> dict[str, Any]:
"""
Builds an ingest query for an HDFS TSV load.
@@ -170,13 +170,13 @@ def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[s
else:
num_shards = -1
- metric_names = [m['fieldName'] for m in self.metric_spec if m['type'] != 'count']
+ metric_names = [m["fieldName"] for m in self.metric_spec if m["type"] != "count"]
# Take all the columns, which are not the time dimension
# or a metric, as the dimension columns
dimensions = [c for c in columns if c not in metric_names and c != self.ts_dim]
- ingest_query_dict: Dict[str, Any] = {
+ ingest_query_dict: dict[str, Any] = {
"type": "index_hadoop",
"spec": {
"dataSchema": {
@@ -220,9 +220,9 @@ def construct_ingest_query(self, static_path: str, columns: List[str]) -> Dict[s
}
if self.job_properties:
- ingest_query_dict['spec']['tuningConfig']['jobProperties'].update(self.job_properties)
+ ingest_query_dict["spec"]["tuningConfig"]["jobProperties"].update(self.job_properties)
if self.hadoop_dependency_coordinates:
- ingest_query_dict['hadoopDependencyCoordinates'] = self.hadoop_dependency_coordinates
+ ingest_query_dict["hadoopDependencyCoordinates"] = self.hadoop_dependency_coordinates
return ingest_query_dict
diff --git a/airflow/providers/apache/hdfs/.latest-doc-only-change.txt b/airflow/providers/apache/hdfs/.latest-doc-only-change.txt
index e7c3c940c9c77..ff7136e07d744 100644
--- a/airflow/providers/apache/hdfs/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/hdfs/.latest-doc-only-change.txt
@@ -1 +1 @@
-602abe8394fafe7de54df7e73af56de848cdf617
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/hdfs/CHANGELOG.rst b/airflow/providers/apache/hdfs/CHANGELOG.rst
index e71546e3a0c7c..b7f11a87c1024 100644
--- a/airflow/providers/apache/hdfs/CHANGELOG.rst
+++ b/airflow/providers/apache/hdfs/CHANGELOG.rst
@@ -16,9 +16,74 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.2.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Adding Authentication to webhdfs sensor (#25110)``
+
+3.0.1
+.....
+
+Bug Fixes
+~~~~~~~~~
+
+* ``'WebHDFSHook' Bugfix/optional port (#24550)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Misc
+~~~~
+
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.2.3
.....
diff --git a/airflow/providers/apache/hdfs/hooks/hdfs.py b/airflow/providers/apache/hdfs/hooks/hdfs.py
index 6d1ce3d010125..0e98320cb77d7 100644
--- a/airflow/providers/apache/hdfs/hooks/hdfs.py
+++ b/airflow/providers/apache/hdfs/hooks/hdfs.py
@@ -16,7 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""Hook for HDFS operations"""
-from typing import Any, Optional
+from __future__ import annotations
+
+from typing import Any
from airflow.configuration import conf
from airflow.exceptions import AirflowException
@@ -43,20 +45,20 @@ class HDFSHook(BaseHook):
:param autoconfig: use snakebite's automatically configured client
"""
- conn_name_attr = 'hdfs_conn_id'
- default_conn_name = 'hdfs_default'
- conn_type = 'hdfs'
- hook_name = 'HDFS'
+ conn_name_attr = "hdfs_conn_id"
+ default_conn_name = "hdfs_default"
+ conn_type = "hdfs"
+ hook_name = "HDFS"
def __init__(
- self, hdfs_conn_id: str = 'hdfs_default', proxy_user: Optional[str] = None, autoconfig: bool = False
+ self, hdfs_conn_id: str = "hdfs_default", proxy_user: str | None = None, autoconfig: bool = False
):
super().__init__()
if not snakebite_loaded:
raise ImportError(
- 'This HDFSHook implementation requires snakebite, but '
- 'snakebite is not compatible with Python 3 '
- '(as of August 2015). Please help by submitting a PR!'
+ "This HDFSHook implementation requires snakebite, but "
+ "snakebite is not compatible with Python 3 "
+ "(as of August 2015). Please help by submitting a PR!"
)
self.hdfs_conn_id = hdfs_conn_id
self.proxy_user = proxy_user
@@ -68,7 +70,7 @@ def get_conn(self) -> Any:
# take the first.
effective_user = self.proxy_user
autoconfig = self.autoconfig
- use_sasl = conf.get('core', 'security') == 'kerberos'
+ use_sasl = conf.get("core", "security") == "kerberos"
try:
connections = self.get_connections(self.hdfs_conn_id)
@@ -76,8 +78,8 @@ def get_conn(self) -> Any:
if not effective_user:
effective_user = connections[0].login
if not autoconfig:
- autoconfig = connections[0].extra_dejson.get('autoconfig', False)
- hdfs_namenode_principal = connections[0].extra_dejson.get('hdfs_namenode_principal')
+ autoconfig = connections[0].extra_dejson.get("autoconfig", False)
+ hdfs_namenode_principal = connections[0].extra_dejson.get("hdfs_namenode_principal")
except AirflowException:
if not autoconfig:
raise
diff --git a/airflow/providers/apache/hdfs/hooks/webhdfs.py b/airflow/providers/apache/hdfs/hooks/webhdfs.py
index a32206ba9bef8..67608c481cd2a 100644
--- a/airflow/providers/apache/hdfs/hooks/webhdfs.py
+++ b/airflow/providers/apache/hdfs/hooks/webhdfs.py
@@ -16,9 +16,11 @@
# specific language governing permissions and limitations
# under the License.
"""Hook for Web HDFS"""
+from __future__ import annotations
+
import logging
import socket
-from typing import Any, Optional
+from typing import Any
import requests
from hdfs import HdfsError, InsecureClient
@@ -50,7 +52,7 @@ class WebHDFSHook(BaseHook):
:param proxy_user: The user used to authenticate.
"""
- def __init__(self, webhdfs_conn_id: str = 'webhdfs_default', proxy_user: Optional[str] = None):
+ def __init__(self, webhdfs_conn_id: str = "webhdfs_default", proxy_user: str | None = None):
super().__init__()
self.webhdfs_conn_id = webhdfs_conn_id
self.proxy_user = proxy_user
@@ -59,7 +61,6 @@ def get_conn(self) -> Any:
"""
Establishes a connection depending on the security mode set via config or environment variable.
:return: a hdfscli InsecureClient or KerberosClient object.
- :rtype: hdfs.InsecureClient or hdfs.ext.kerberos.KerberosClient
"""
connection = self._find_valid_server()
if connection is None:
@@ -68,42 +69,54 @@ def get_conn(self) -> Any:
def _find_valid_server(self) -> Any:
connection = self.get_connection(self.webhdfs_conn_id)
- namenodes = connection.host.split(',')
+ namenodes = connection.host.split(",")
for namenode in namenodes:
host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.log.info("Trying to connect to %s:%s", namenode, connection.port)
try:
conn_check = host_socket.connect_ex((namenode, connection.port))
if conn_check == 0:
- self.log.info('Trying namenode %s', namenode)
+ self.log.info("Trying namenode %s", namenode)
client = self._get_client(
- namenode, connection.port, connection.login, connection.extra_dejson
+ namenode,
+ connection.port,
+ connection.login,
+ connection.get_password(),
+ connection.schema,
+ connection.extra_dejson,
)
- client.status('/')
- self.log.info('Using namenode %s for hook', namenode)
+ client.status("/")
+ self.log.info("Using namenode %s for hook", namenode)
host_socket.close()
return client
else:
self.log.warning("Could not connect to %s:%s", namenode, connection.port)
except HdfsError as hdfs_error:
- self.log.info('Read operation on namenode %s failed with error: %s', namenode, hdfs_error)
+ self.log.info("Read operation on namenode %s failed with error: %s", namenode, hdfs_error)
return None
- def _get_client(self, namenode: str, port: int, login: str, extra_dejson: dict) -> Any:
- connection_str = f'http://{namenode}:{port}'
+ def _get_client(
+ self, namenode: str, port: int, login: str, password: str | None, schema: str, extra_dejson: dict
+ ) -> Any:
+ connection_str = f"http://{namenode}"
session = requests.Session()
+ if password is not None:
+ session.auth = (login, password)
- if extra_dejson.get('use_ssl', False):
- connection_str = f'https://{namenode}:{port}'
- session.verify = extra_dejson.get('verify', True)
+ if extra_dejson.get("use_ssl", "False") == "True" or extra_dejson.get("use_ssl", False):
+ connection_str = f"https://{namenode}"
+ session.verify = extra_dejson.get("verify", False)
- if _kerberos_security_mode:
- client = KerberosClient(connection_str, session=session)
- else:
- proxy_user = self.proxy_user or login
- client = InsecureClient(connection_str, user=proxy_user, session=session)
+ if port is not None:
+ connection_str += f":{port}"
- return client
+ if schema is not None:
+ connection_str += f"/{schema}"
+
+ if _kerberos_security_mode:
+ return KerberosClient(connection_str, session=session)
+ proxy_user = self.proxy_user or login
+ return InsecureClient(connection_str, user=proxy_user, session=session)
def check_for_path(self, hdfs_path: str) -> bool:
"""
@@ -111,7 +124,6 @@ def check_for_path(self, hdfs_path: str) -> bool:
:param hdfs_path: The path to check.
:return: True if the path exists and False if not.
- :rtype: bool
"""
conn = self.get_conn()
diff --git a/airflow/providers/apache/hdfs/provider.yaml b/airflow/providers/apache/hdfs/provider.yaml
index a01a734d72629..2f575d7da32da 100644
--- a/airflow/providers/apache/hdfs/provider.yaml
+++ b/airflow/providers/apache/hdfs/provider.yaml
@@ -23,6 +23,10 @@ description: |
and `WebHDFS `__.
versions:
+ - 3.2.0
+ - 3.1.0
+ - 3.0.1
+ - 3.0.0
- 2.2.3
- 2.2.2
- 2.2.1
@@ -33,8 +37,10 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - snakebite-py3
+ - hdfs[avro,dataframe,kerberos]>=2.0.4
integrations:
- integration-name: Hadoop Distributed File System (HDFS)
@@ -66,9 +72,6 @@ hooks:
python-modules:
- airflow.providers.apache.hdfs.hooks.webhdfs
-hook-class-names:
- # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook
connection-types:
- hook-class-name: airflow.providers.apache.hdfs.hooks.hdfs.HDFSHook
diff --git a/airflow/providers/apache/hdfs/sensors/hdfs.py b/airflow/providers/apache/hdfs/sensors/hdfs.py
index a445bb688e722..5c55209f4d1de 100644
--- a/airflow/providers/apache/hdfs/sensors/hdfs.py
+++ b/airflow/providers/apache/hdfs/sensors/hdfs.py
@@ -15,10 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import logging
import re
import sys
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Sequence, Type
+from typing import TYPE_CHECKING, Any, Pattern, Sequence
from airflow import settings
from airflow.providers.apache.hdfs.hooks.hdfs import HDFSHook
@@ -45,23 +47,23 @@ class HdfsSensor(BaseSensorOperator):
:ref:`howto/operator:HdfsSensor`
"""
- template_fields: Sequence[str] = ('filepath',)
- ui_color = settings.WEB_COLORS['LIGHTBLUE']
+ template_fields: Sequence[str] = ("filepath",)
+ ui_color = settings.WEB_COLORS["LIGHTBLUE"]
def __init__(
self,
*,
filepath: str,
- hdfs_conn_id: str = 'hdfs_default',
- ignored_ext: Optional[List[str]] = None,
+ hdfs_conn_id: str = "hdfs_default",
+ ignored_ext: list[str] | None = None,
ignore_copying: bool = True,
- file_size: Optional[int] = None,
- hook: Type[HDFSHook] = HDFSHook,
+ file_size: int | None = None,
+ hook: type[HDFSHook] = HDFSHook,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if ignored_ext is None:
- ignored_ext = ['_COPYING_']
+ ignored_ext = ["_COPYING_"]
self.filepath = filepath
self.hdfs_conn_id = hdfs_conn_id
self.file_size = file_size
@@ -70,7 +72,7 @@ def __init__(
self.hook = hook
@staticmethod
- def filter_for_filesize(result: List[Dict[Any, Any]], size: Optional[int] = None) -> List[Dict[Any, Any]]:
+ def filter_for_filesize(result: list[dict[Any, Any]], size: int | None = None) -> list[dict[Any, Any]]:
"""
Will test the filepath result and test if its size is at least self.filesize
@@ -79,16 +81,16 @@ def filter_for_filesize(result: List[Dict[Any, Any]], size: Optional[int] = None
:return: (bool) depending on the matching criteria
"""
if size:
- log.debug('Filtering for file size >= %s in files: %s', size, map(lambda x: x['path'], result))
+ log.debug("Filtering for file size >= %s in files: %s", size, map(lambda x: x["path"], result))
size *= settings.MEGABYTE
- result = [x for x in result if x['length'] >= size]
- log.debug('HdfsSensor.poke: after size filter result is %s', result)
+ result = [x for x in result if x["length"] >= size]
+ log.debug("HdfsSensor.poke: after size filter result is %s", result)
return result
@staticmethod
def filter_for_ignored_ext(
- result: List[Dict[Any, Any]], ignored_ext: List[str], ignore_copying: bool
- ) -> List[Dict[Any, Any]]:
+ result: list[dict[Any, Any]], ignored_ext: list[str], ignore_copying: bool
+ ) -> list[dict[Any, Any]]:
"""
Will filter if instructed to do so the result to remove matching criteria
@@ -96,24 +98,23 @@ def filter_for_ignored_ext(
:param ignored_ext: list of ignored extensions
:param ignore_copying: shall we ignore ?
:return: list of dicts which were not removed
- :rtype: list[dict]
"""
if ignore_copying:
- regex_builder = r"^.*\.(%s$)$" % '$|'.join(ignored_ext)
+ regex_builder = r"^.*\.(%s$)$" % "$|".join(ignored_ext)
ignored_extensions_regex = re.compile(regex_builder)
log.debug(
- 'Filtering result for ignored extensions: %s in files %s',
+ "Filtering result for ignored extensions: %s in files %s",
ignored_extensions_regex.pattern,
- map(lambda x: x['path'], result),
+ map(lambda x: x["path"], result),
)
- result = [x for x in result if not ignored_extensions_regex.match(x['path'])]
- log.debug('HdfsSensor.poke: after ext filter result is %s', result)
+ result = [x for x in result if not ignored_extensions_regex.match(x["path"])]
+ log.debug("HdfsSensor.poke: after ext filter result is %s", result)
return result
- def poke(self, context: "Context") -> bool:
+ def poke(self, context: Context) -> bool:
"""Get a snakebite client connection and check for file."""
sb_client = self.hook(self.hdfs_conn_id).get_conn()
- self.log.info('Poking for file %s', self.filepath)
+ self.log.info("Poking for file %s", self.filepath)
try:
# IMOO it's not right here, as there is no raise of any kind.
# if the filepath is let's say '/data/mydirectory',
@@ -121,7 +122,7 @@ def poke(self, context: "Context") -> bool:
# it's not correct as the directory exists and sb_client does not raise any error
# here is a quick fix
result = sb_client.ls([self.filepath], include_toplevel=False)
- self.log.debug('HdfsSensor.poke: result is %s', result)
+ self.log.debug("HdfsSensor.poke: result is %s", result)
result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying)
result = self.filter_for_filesize(result, self.file_size)
return bool(result)
@@ -144,7 +145,7 @@ def __init__(self, regex: Pattern[str], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.regex = regex
- def poke(self, context: "Context") -> bool:
+ def poke(self, context: Context) -> bool:
"""
Poke matching files in a directory with self.regex
@@ -152,12 +153,12 @@ def poke(self, context: "Context") -> bool:
"""
sb_client = self.hook(self.hdfs_conn_id).get_conn()
self.log.info(
- 'Poking for %s to be a directory with files matching %s', self.filepath, self.regex.pattern
+ "Poking for %s to be a directory with files matching %s", self.filepath, self.regex.pattern
)
result = [
f
for f in sb_client.ls([self.filepath], include_toplevel=False)
- if f['file_type'] == 'f' and self.regex.match(f['path'].replace(f'{self.filepath}/', ''))
+ if f["file_type"] == "f" and self.regex.match(f["path"].replace(f"{self.filepath}/", ""))
]
result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying)
result = self.filter_for_filesize(result, self.file_size)
@@ -177,7 +178,7 @@ def __init__(self, be_empty: bool = False, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.be_empty = be_empty
- def poke(self, context: "Context") -> bool:
+ def poke(self, context: Context) -> bool:
"""
Poke for a non empty directory
@@ -188,9 +189,9 @@ def poke(self, context: "Context") -> bool:
result = self.filter_for_ignored_ext(result, self.ignored_ext, self.ignore_copying)
result = self.filter_for_filesize(result, self.file_size)
if self.be_empty:
- self.log.info('Poking for filepath %s to a empty directory', self.filepath)
- return len(result) == 1 and result[0]['path'] == self.filepath
+ self.log.info("Poking for filepath %s to a empty directory", self.filepath)
+ return len(result) == 1 and result[0]["path"] == self.filepath
else:
- self.log.info('Poking for filepath %s to a non empty directory', self.filepath)
+ self.log.info("Poking for filepath %s to a non empty directory", self.filepath)
result.pop(0)
- return bool(result) and result[0]['file_type'] == 'f'
+ return bool(result) and result[0]["file_type"] == "f"
diff --git a/airflow/providers/apache/hdfs/sensors/web_hdfs.py b/airflow/providers/apache/hdfs/sensors/web_hdfs.py
index adacdefecad07..38e1047679b03 100644
--- a/airflow/providers/apache/hdfs/sensors/web_hdfs.py
+++ b/airflow/providers/apache/hdfs/sensors/web_hdfs.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Any, Sequence
from airflow.sensors.base import BaseSensorOperator
@@ -26,16 +28,16 @@
class WebHdfsSensor(BaseSensorOperator):
"""Waits for a file or folder to land in HDFS"""
- template_fields: Sequence[str] = ('filepath',)
+ template_fields: Sequence[str] = ("filepath",)
- def __init__(self, *, filepath: str, webhdfs_conn_id: str = 'webhdfs_default', **kwargs: Any) -> None:
+ def __init__(self, *, filepath: str, webhdfs_conn_id: str = "webhdfs_default", **kwargs: Any) -> None:
super().__init__(**kwargs)
self.filepath = filepath
self.webhdfs_conn_id = webhdfs_conn_id
- def poke(self, context: "Context") -> bool:
+ def poke(self, context: Context) -> bool:
from airflow.providers.apache.hdfs.hooks.webhdfs import WebHDFSHook
hook = WebHDFSHook(self.webhdfs_conn_id)
- self.log.info('Poking for file %s', self.filepath)
+ self.log.info("Poking for file %s", self.filepath)
return hook.check_for_path(hdfs_path=self.filepath)
diff --git a/airflow/providers/apache/hive/CHANGELOG.rst b/airflow/providers/apache/hive/CHANGELOG.rst
index bb28c620b4126..19543da737b5b 100644
--- a/airflow/providers/apache/hive/CHANGELOG.rst
+++ b/airflow/providers/apache/hive/CHANGELOG.rst
@@ -16,9 +16,104 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please only add notes to the Changelog just below the "Changelog" header when there are breaking changes
+ and you want to explain to users how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by the release manager.
+
Changelog
---------
+4.1.0
+.....
+
+This release of the provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Filter out invalid schemas in Hive hook (#27647)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+4.0.1
+.....
+
+Misc
+~~~~
+
+* ``Add common-sql lower bound for common-sql (#25789)``
+
+.. Review and move the new changes to one of the sections above:
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+4.0.0
+.....
+
+Breaking Changes
+~~~~~~~~~~~~~~~~
+
+* The ``hql`` parameter in ``get_records`` of ``HiveServer2Hook`` has been renamed to ``sql`` to match the
+ ``get_records`` DbApiHook signature. If you passed it as a positional parameter, nothing changes for you,
+ but if you passed it as a keyword argument, you need to rename it.
+* The ``hive_conf`` parameter has been renamed to ``parameters`` and is now the second parameter, to match the
+ ``get_records`` signature from the DbApiHook. You need to rename it if you used it.
+* The ``schema`` parameter in ``get_records`` is an optional extra keyword argument that you can add, to match
+ the signature of ``get_records`` from DbApiHook.
+
+* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)``
+* ``Remove Smart Sensors (#25507)``
+
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Move all SQL classes to common-sql provider (#24836)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``fix connection extra parameter 'auth_mechanism' in 'HiveMetastoreHook' and 'HiveServer2Hook' (#24713)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of the provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Misc
+~~~~
+
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+* ``AIP-47 - Migrate hive DAGs to new design #22439 (#24204)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add typing for airflow/configuration.py (#23716)``
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.3.3
.....
diff --git a/airflow/providers/apache/hive/hooks/hive.py b/airflow/providers/apache/hive/hooks/hive.py
index bccd279f77cf0..edd8a4b372c06 100644
--- a/airflow/providers/apache/hive/hooks/hive.py
+++ b/airflow/providers/apache/hive/hooks/hive.py
@@ -15,15 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import contextlib
import os
import re
import socket
import subprocess
import time
+import warnings
from collections import OrderedDict
from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Iterable, Mapping
import pandas
import unicodecsv as csv
@@ -31,15 +34,15 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.hooks.dbapi import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.security import utils
from airflow.utils.helpers import as_flattened_list
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING
-HIVE_QUEUE_PRIORITIES = ['VERY_HIGH', 'HIGH', 'NORMAL', 'LOW', 'VERY_LOW']
+HIVE_QUEUE_PRIORITIES = ["VERY_HIGH", "HIGH", "NORMAL", "LOW", "VERY_LOW"]
-def get_context_from_env_var() -> Dict[Any, Any]:
+def get_context_from_env_var() -> dict[Any, Any]:
"""
Extract context from env variable, e.g. dag_id, task_id and execution_date,
so that they can be used inside BashOperator and PythonOperator.
@@ -47,7 +50,7 @@ def get_context_from_env_var() -> Dict[Any, Any]:
:return: The context of interest.
"""
return {
- format_map['default']: os.environ.get(format_map['env_var_format'], '')
+ format_map["default"]: os.environ.get(format_map["env_var_format"], "")
for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
}
@@ -77,24 +80,24 @@ class HiveCliHook(BaseHook):
This can make monitoring easier.
"""
- conn_name_attr = 'hive_cli_conn_id'
- default_conn_name = 'hive_cli_default'
- conn_type = 'hive_cli'
- hook_name = 'Hive Client Wrapper'
+ conn_name_attr = "hive_cli_conn_id"
+ default_conn_name = "hive_cli_default"
+ conn_type = "hive_cli"
+ hook_name = "Hive Client Wrapper"
def __init__(
self,
hive_cli_conn_id: str = default_conn_name,
- run_as: Optional[str] = None,
- mapred_queue: Optional[str] = None,
- mapred_queue_priority: Optional[str] = None,
- mapred_job_name: Optional[str] = None,
+ run_as: str | None = None,
+ mapred_queue: str | None = None,
+ mapred_queue_priority: str | None = None,
+ mapred_job_name: str | None = None,
) -> None:
super().__init__()
conn = self.get_connection(hive_cli_conn_id)
- self.hive_cli_params: str = conn.extra_dejson.get('hive_cli_params', '')
- self.use_beeline: bool = conn.extra_dejson.get('use_beeline', False)
- self.auth = conn.extra_dejson.get('auth', 'noSasl')
+ self.hive_cli_params: str = conn.extra_dejson.get("hive_cli_params", "")
+ self.use_beeline: bool = conn.extra_dejson.get("use_beeline", False)
+ self.auth = conn.extra_dejson.get("auth", "noSasl")
self.conn = conn
self.run_as = run_as
self.sub_process: Any = None
@@ -106,7 +109,7 @@ def __init__(
f"Invalid Mapred Queue Priority. Valid values are: {', '.join(HIVE_QUEUE_PRIORITIES)}"
)
- self.mapred_queue = mapred_queue or conf.get('hive', 'default_hive_mapred_queue')
+ self.mapred_queue = mapred_queue or conf.get("hive", "default_hive_mapred_queue")
self.mapred_queue_priority = mapred_queue_priority
self.mapred_job_name = mapred_job_name
@@ -114,7 +117,7 @@ def _get_proxy_user(self) -> str:
"""This function set the proper proxy_user value in case the user overwrite the default."""
conn = self.conn
- proxy_user_value: str = conn.extra_dejson.get('proxy_user', "")
+ proxy_user_value: str = conn.extra_dejson.get("proxy_user", "")
if proxy_user_value == "login" and conn.login:
return f"hive.server2.proxy.user={conn.login}"
if proxy_user_value == "owner" and self.run_as:
@@ -123,17 +126,17 @@ def _get_proxy_user(self) -> str:
return f"hive.server2.proxy.user={proxy_user_value}"
return proxy_user_value # The default proxy user (undefined)
- def _prepare_cli_cmd(self) -> List[Any]:
+ def _prepare_cli_cmd(self) -> list[Any]:
"""This function creates the command list from available information"""
conn = self.conn
- hive_bin = 'hive'
+ hive_bin = "hive"
cmd_extra = []
if self.use_beeline:
- hive_bin = 'beeline'
+ hive_bin = "beeline"
jdbc_url = f"jdbc:hive2://{conn.host}:{conn.port}/{conn.schema}"
- if conf.get('core', 'security') == 'kerberos':
- template = conn.extra_dejson.get('principal', "hive/_HOST@EXAMPLE.COM")
+ if conf.get("core", "security") == "kerberos":
+ template = conn.extra_dejson.get("principal", "hive/_HOST@EXAMPLE.COM")
if "_HOST" in template:
template = utils.replace_hostname_pattern(utils.get_components(template))
@@ -145,18 +148,18 @@ def _prepare_cli_cmd(self) -> List[Any]:
jdbc_url = f'"{jdbc_url}"'
- cmd_extra += ['-u', jdbc_url]
+ cmd_extra += ["-u", jdbc_url]
if conn.login:
- cmd_extra += ['-n', conn.login]
+ cmd_extra += ["-n", conn.login]
if conn.password:
- cmd_extra += ['-p', conn.password]
+ cmd_extra += ["-p", conn.password]
hive_params_list = self.hive_cli_params.split()
return [hive_bin] + cmd_extra + hive_params_list
@staticmethod
- def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]:
+ def _prepare_hiveconf(d: dict[Any, Any]) -> list[Any]:
"""
This function prepares a list of hiveconf params
from a dictionary of key value pairs.
@@ -177,9 +180,9 @@ def _prepare_hiveconf(d: Dict[Any, Any]) -> List[Any]:
def run_cli(
self,
hql: str,
- schema: Optional[str] = None,
+ schema: str | None = None,
verbose: bool = True,
- hive_conf: Optional[Dict[Any, Any]] = None,
+ hive_conf: dict[Any, Any] | None = None,
) -> Any:
"""
Run an hql statement using the hive cli. If hive_conf is specified
@@ -201,13 +204,15 @@ def run_cli(
"""
conn = self.conn
schema = schema or conn.schema
+ if schema and ("!" in schema or ";" in schema):
+ raise RuntimeError(f"The schema `{schema}` contains invalid characters (!;)")
if schema:
hql = f"USE {schema};\n{hql}"
- with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir:
+ with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir) as f:
- hql += '\n'
- f.write(hql.encode('UTF-8'))
+ hql += "\n"
+ f.write(hql.encode("UTF-8"))
f.flush()
hive_cmd = self._prepare_cli_cmd()
env_context = get_context_from_env_var()
@@ -218,25 +223,25 @@ def run_cli(
if self.mapred_queue:
hive_conf_params.extend(
[
- '-hiveconf',
- f'mapreduce.job.queuename={self.mapred_queue}',
- '-hiveconf',
- f'mapred.job.queue.name={self.mapred_queue}',
- '-hiveconf',
- f'tez.queue.name={self.mapred_queue}',
+ "-hiveconf",
+ f"mapreduce.job.queuename={self.mapred_queue}",
+ "-hiveconf",
+ f"mapred.job.queue.name={self.mapred_queue}",
+ "-hiveconf",
+ f"tez.queue.name={self.mapred_queue}",
]
)
if self.mapred_queue_priority:
hive_conf_params.extend(
- ['-hiveconf', f'mapreduce.job.priority={self.mapred_queue_priority}']
+ ["-hiveconf", f"mapreduce.job.priority={self.mapred_queue_priority}"]
)
if self.mapred_job_name:
- hive_conf_params.extend(['-hiveconf', f'mapred.job.name={self.mapred_job_name}'])
+ hive_conf_params.extend(["-hiveconf", f"mapred.job.name={self.mapred_job_name}"])
hive_cmd.extend(hive_conf_params)
- hive_cmd.extend(['-f', f.name])
+ hive_cmd.extend(["-f", f.name])
if verbose:
self.log.info("%s", " ".join(hive_cmd))
@@ -244,14 +249,14 @@ def run_cli(
hive_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
- stdout = ''
+ stdout = ""
while True:
line = sub_process.stdout.readline()
if not line:
break
- stdout += line.decode('UTF-8')
+ stdout += line.decode("UTF-8")
if verbose:
- self.log.info(line.decode('UTF-8').strip())
+ self.log.info(line.decode("UTF-8").strip())
sub_process.wait()
if sub_process.returncode:
@@ -262,37 +267,37 @@ def run_cli(
def test_hql(self, hql: str) -> None:
"""Test an hql statement using the hive cli and EXPLAIN"""
create, insert, other = [], [], []
- for query in hql.split(';'): # naive
+ for query in hql.split(";"): # naive
query_original = query
query = query.lower().strip()
- if query.startswith('create table'):
+ if query.startswith("create table"):
create.append(query_original)
- elif query.startswith(('set ', 'add jar ', 'create temporary function')):
+ elif query.startswith(("set ", "add jar ", "create temporary function")):
other.append(query_original)
- elif query.startswith('insert'):
+ elif query.startswith("insert"):
insert.append(query_original)
- other_ = ';'.join(other)
+ other_ = ";".join(other)
for query_set in [create, insert]:
for query in query_set:
- query_preview = ' '.join(query.split())[:50]
+ query_preview = " ".join(query.split())[:50]
self.log.info("Testing HQL [%s (...)]", query_preview)
if query_set == insert:
- query = other_ + '; explain ' + query
+ query = other_ + "; explain " + query
else:
- query = 'explain ' + query
+ query = "explain " + query
try:
self.run_cli(query, verbose=False)
except AirflowException as e:
- message = e.args[0].split('\n')[-2]
+ message = e.args[0].split("\n")[-2]
self.log.info(message)
- error_loc = re.search(r'(\d+):(\d+)', message)
+ error_loc = re.search(r"(\d+):(\d+)", message)
if error_loc and error_loc.group(1).isdigit():
lst = int(error_loc.group(1))
begin = max(lst - 2, 0)
- end = min(lst + 3, len(query.split('\n')))
- context = '\n'.join(query.split('\n')[begin:end])
+ end = min(lst + 3, len(query.split("\n")))
+ context = "\n".join(query.split("\n")[begin:end])
self.log.info("Context :\n %s", context)
else:
self.log.info("SUCCESS")
@@ -301,9 +306,9 @@ def load_df(
self,
df: pandas.DataFrame,
table: str,
- field_dict: Optional[Dict[Any, Any]] = None,
- delimiter: str = ',',
- encoding: str = 'utf8',
+ field_dict: dict[Any, Any] | None = None,
+ delimiter: str = ",",
+ encoding: str = "utf8",
pandas_kwargs: Any = None,
**kwargs: Any,
) -> None:
@@ -324,18 +329,18 @@ def load_df(
:param kwargs: passed to self.load_file
"""
- def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]:
+ def _infer_field_types_from_df(df: pandas.DataFrame) -> dict[Any, Any]:
dtype_kind_hive_type = {
- 'b': 'BOOLEAN', # boolean
- 'i': 'BIGINT', # signed integer
- 'u': 'BIGINT', # unsigned integer
- 'f': 'DOUBLE', # floating-point
- 'c': 'STRING', # complex floating-point
- 'M': 'TIMESTAMP', # datetime
- 'O': 'STRING', # object
- 'S': 'STRING', # (byte-)string
- 'U': 'STRING', # Unicode
- 'V': 'STRING', # void
+ "b": "BOOLEAN", # boolean
+ "i": "BIGINT", # signed integer
+ "u": "BIGINT", # unsigned integer
+ "f": "DOUBLE", # floating-point
+ "c": "STRING", # complex floating-point
+ "M": "TIMESTAMP", # datetime
+ "O": "STRING", # object
+ "S": "STRING", # (byte-)string
+ "U": "STRING", # Unicode
+ "V": "STRING", # void
}
order_type = OrderedDict()
@@ -346,7 +351,7 @@ def _infer_field_types_from_df(df: pandas.DataFrame) -> Dict[Any, Any]:
if pandas_kwargs is None:
pandas_kwargs = {}
- with TemporaryDirectory(prefix='airflow_hiveop_') as tmp_dir:
+ with TemporaryDirectory(prefix="airflow_hiveop_") as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir, mode="w") as f:
if field_dict is None:
field_dict = _infer_field_types_from_df(df)
@@ -371,12 +376,12 @@ def load_file(
filepath: str,
table: str,
delimiter: str = ",",
- field_dict: Optional[Dict[Any, Any]] = None,
+ field_dict: dict[Any, Any] | None = None,
create: bool = True,
overwrite: bool = True,
- partition: Optional[Dict[str, Any]] = None,
+ partition: dict[str, Any] | None = None,
recreate: bool = False,
- tblproperties: Optional[Dict[str, Any]] = None,
+ tblproperties: dict[str, Any] | None = None,
) -> None:
"""
Loads a local file into Hive
@@ -403,7 +408,7 @@ def load_file(
execution
:param tblproperties: TBLPROPERTIES of the hive table being created
"""
- hql = ''
+ hql = ""
if recreate:
hql += f"DROP TABLE IF EXISTS {table};\n"
if create or recreate:
@@ -433,14 +438,14 @@ def load_file(
# As a workaround for HIVE-10541, add a newline character
# at the end of hql (AIRFLOW-2412).
- hql += ';\n'
+ hql += ";\n"
self.log.info(hql)
self.run_cli(hql)
def kill(self) -> None:
"""Kill Hive cli command"""
- if hasattr(self, 'sub_process'):
+ if hasattr(self, "sub_process"):
if self.sub_process.poll() is None:
print("Killing the Hive job")
self.sub_process.terminate()
@@ -459,26 +464,26 @@ class HiveMetastoreHook(BaseHook):
# java short max val
MAX_PART_COUNT = 32767
- conn_name_attr = 'metastore_conn_id'
- default_conn_name = 'metastore_default'
- conn_type = 'hive_metastore'
- hook_name = 'Hive Metastore Thrift'
+ conn_name_attr = "metastore_conn_id"
+ default_conn_name = "metastore_default"
+ conn_type = "hive_metastore"
+ hook_name = "Hive Metastore Thrift"
def __init__(self, metastore_conn_id: str = default_conn_name) -> None:
super().__init__()
self.conn = self.get_connection(metastore_conn_id)
self.metastore = self.get_metastore_client()
- def __getstate__(self) -> Dict[str, Any]:
+ def __getstate__(self) -> dict[str, Any]:
# This is for pickling to work despite the thrift hive client not
# being picklable
state = dict(self.__dict__)
- del state['metastore']
+ del state["metastore"]
return state
- def __setstate__(self, d: Dict[str, Any]) -> None:
+ def __setstate__(self, d: dict[str, Any]) -> None:
self.__dict__.update(d)
- self.__dict__['metastore'] = self.get_metastore_client()
+ self.__dict__["metastore"] = self.get_metastore_client()
def get_metastore_client(self) -> Any:
"""Returns a Hive thrift client."""
@@ -492,15 +497,24 @@ def get_metastore_client(self) -> Any:
if not host:
raise AirflowException("Failed to locate the valid server.")
- auth_mechanism = conn.extra_dejson.get('authMechanism', 'NOSASL')
+ if "authMechanism" in conn.extra_dejson:
+ warnings.warn(
+ "The 'authMechanism' option is deprecated. Please use 'auth_mechanism'.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ conn.extra_dejson["auth_mechanism"] = conn.extra_dejson["authMechanism"]
+ del conn.extra_dejson["authMechanism"]
+
+ auth_mechanism = conn.extra_dejson.get("auth_mechanism", "NOSASL")
- if conf.get('core', 'security') == 'kerberos':
- auth_mechanism = conn.extra_dejson.get('authMechanism', 'GSSAPI')
- kerberos_service_name = conn.extra_dejson.get('kerberos_service_name', 'hive')
+ if conf.get("core", "security") == "kerberos":
+ auth_mechanism = conn.extra_dejson.get("auth_mechanism", "GSSAPI")
+ kerberos_service_name = conn.extra_dejson.get("kerberos_service_name", "hive")
conn_socket = TSocket.TSocket(host, conn.port)
- if conf.get('core', 'security') == 'kerberos' and auth_mechanism == 'GSSAPI':
+ if conf.get("core", "security") == "kerberos" and auth_mechanism == "GSSAPI":
try:
import saslwrapper as sasl
except ImportError:
@@ -525,7 +539,7 @@ def sasl_factory() -> sasl.Client:
def _find_valid_host(self) -> Any:
conn = self.conn
- hosts = conn.host.split(',')
+ hosts = conn.host.split(",")
for host in hosts:
host_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.log.info("Trying to connect to %s:%s", host, conn.port)
@@ -548,7 +562,6 @@ def check_for_partition(self, schema: str, table: str, partition: str) -> bool:
:param table: Name of hive table @partition belongs to
:param partition: Expression that matches the partitions to check for
(eg `a = 'b' AND c = 'd'`)
- :rtype: bool
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
@@ -569,7 +582,6 @@ def check_for_named_partition(self, schema: str, table: str, partition_name: str
:param schema: Name of hive schema (database) @table belongs to
:param table: Name of hive table @partition belongs to
:param partition_name: Name of the partitions to check for (eg `a=b/c=d`)
- :rtype: bool
>>> hh = HiveMetastoreHook()
>>> t = 'static_babynames_partitioned'
@@ -581,7 +593,7 @@ def check_for_named_partition(self, schema: str, table: str, partition_name: str
with self.metastore as client:
return client.check_for_named_partition(schema, table, partition_name)
- def get_table(self, table_name: str, db: str = 'default') -> Any:
+ def get_table(self, table_name: str, db: str = "default") -> Any:
"""Get a metastore table object
>>> hh = HiveMetastoreHook()
@@ -591,25 +603,23 @@ def get_table(self, table_name: str, db: str = 'default') -> Any:
>>> [col.name for col in t.sd.cols]
['state', 'year', 'name', 'gender', 'num']
"""
- if db == 'default' and '.' in table_name:
- db, table_name = table_name.split('.')[:2]
+ if db == "default" and "." in table_name:
+ db, table_name = table_name.split(".")[:2]
with self.metastore as client:
return client.get_table(dbname=db, tbl_name=table_name)
- def get_tables(self, db: str, pattern: str = '*') -> Any:
+ def get_tables(self, db: str, pattern: str = "*") -> Any:
"""Get a metastore table object"""
with self.metastore as client:
tables = client.get_tables(db_name=db, pattern=pattern)
return client.get_table_objects_by_name(db, tables)
- def get_databases(self, pattern: str = '*') -> Any:
+ def get_databases(self, pattern: str = "*") -> Any:
"""Get a metastore table object"""
with self.metastore as client:
return client.get_databases(pattern)
- def get_partitions(
- self, schema: str, table_name: str, partition_filter: Optional[str] = None
- ) -> List[Any]:
+ def get_partitions(self, schema: str, table_name: str, partition_filter: str | None = None) -> list[Any]:
"""
Returns a list of all partitions in a table. Works only
for tables with less than 32767 (java short max val).
@@ -645,7 +655,7 @@ def get_partitions(
@staticmethod
def _get_max_partition_from_part_specs(
- part_specs: List[Any], partition_key: Optional[str], filter_map: Optional[Dict[str, Any]]
+ part_specs: list[Any], partition_key: str | None, filter_map: dict[str, Any] | None
) -> Any:
"""
Helper method to get max partition of partitions with partition_key
@@ -659,7 +669,6 @@ def _get_max_partition_from_part_specs(
Only partitions matching all partition_key:partition_value
pairs will be considered as candidates of max partition.
:return: Max partition or None if part_specs is empty.
- :rtype: basestring
"""
if not part_specs:
return None
@@ -691,8 +700,8 @@ def max_partition(
self,
schema: str,
table_name: str,
- field: Optional[str] = None,
- filter_map: Optional[Dict[Any, Any]] = None,
+ field: str | None = None,
+ filter_map: dict[Any, Any] | None = None,
) -> Any:
"""
Returns the maximum value for all partitions with given field in a table.
@@ -732,7 +741,7 @@ def max_partition(
return HiveMetastoreHook._get_max_partition_from_part_specs(part_specs, field, filter_map)
- def table_exists(self, table_name: str, db: str = 'default') -> bool:
+ def table_exists(self, table_name: str, db: str = "default") -> bool:
"""
Check if table exists
@@ -748,7 +757,7 @@ def table_exists(self, table_name: str, db: str = 'default') -> bool:
except Exception:
return False
- def drop_partitions(self, table_name, part_vals, delete_data=False, db='default'):
+ def drop_partitions(self, table_name, part_vals, delete_data=False, db="default"):
"""
Drop partitions from the given table matching the part_vals input
@@ -779,7 +788,7 @@ class HiveServer2Hook(DbApiHook):
Wrapper around the pyhive library
Notes:
- * the default authMechanism is PLAIN, to override it you
+ * the default auth_mechanism is PLAIN, to override it you
can specify it in the ``extra`` of your connection in the UI
* the default for run_set_variable_statements is true, if you
are using impala you may need to set it to false in the
@@ -790,38 +799,47 @@ class HiveServer2Hook(DbApiHook):
:param schema: Hive database name.
"""
- conn_name_attr = 'hiveserver2_conn_id'
- default_conn_name = 'hiveserver2_default'
- conn_type = 'hiveserver2'
- hook_name = 'Hive Server 2 Thrift'
+ conn_name_attr = "hiveserver2_conn_id"
+ default_conn_name = "hiveserver2_default"
+ conn_type = "hiveserver2"
+ hook_name = "Hive Server 2 Thrift"
supports_autocommit = False
- def get_conn(self, schema: Optional[str] = None) -> Any:
+ def get_conn(self, schema: str | None = None) -> Any:
"""Returns a Hive connection object."""
- username: Optional[str] = None
- password: Optional[str] = None
+ username: str | None = None
+ password: str | None = None
db = self.get_connection(self.hiveserver2_conn_id) # type: ignore
- auth_mechanism = db.extra_dejson.get('authMechanism', 'NONE')
- if auth_mechanism == 'NONE' and db.login is None:
+ if "authMechanism" in db.extra_dejson:
+ warnings.warn(
+ "The 'authMechanism' option is deprecated. Please use 'auth_mechanism'.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+ db.extra_dejson["auth_mechanism"] = db.extra_dejson["authMechanism"]
+ del db.extra_dejson["authMechanism"]
+
+ auth_mechanism = db.extra_dejson.get("auth_mechanism", "NONE")
+ if auth_mechanism == "NONE" and db.login is None:
# we need to give a username
- username = 'airflow'
+ username = "airflow"
kerberos_service_name = None
- if conf.get('core', 'security') == 'kerberos':
- auth_mechanism = db.extra_dejson.get('authMechanism', 'KERBEROS')
- kerberos_service_name = db.extra_dejson.get('kerberos_service_name', 'hive')
+ if conf.get("core", "security") == "kerberos":
+ auth_mechanism = db.extra_dejson.get("auth_mechanism", "KERBEROS")
+ kerberos_service_name = db.extra_dejson.get("kerberos_service_name", "hive")
# pyhive uses GSSAPI instead of KERBEROS as a auth_mechanism identifier
- if auth_mechanism == 'GSSAPI':
+ if auth_mechanism == "GSSAPI":
self.log.warning(
- "Detected deprecated 'GSSAPI' for authMechanism for %s. Please use 'KERBEROS' instead",
+ "Detected deprecated 'GSSAPI' for auth_mechanism for %s. Please use 'KERBEROS' instead",
self.hiveserver2_conn_id, # type: ignore
)
- auth_mechanism = 'KERBEROS'
+ auth_mechanism = "KERBEROS"
# Password should be set if and only if in LDAP or CUSTOM mode
- if auth_mechanism in ('LDAP', 'CUSTOM'):
+ if auth_mechanism in ("LDAP", "CUSTOM"):
password = db.password
from pyhive.hive import connect
@@ -833,20 +851,20 @@ def get_conn(self, schema: Optional[str] = None) -> Any:
kerberos_service_name=kerberos_service_name,
username=db.login or username,
password=password,
- database=schema or db.schema or 'default',
+ database=schema or db.schema or "default",
)
def _get_results(
self,
- hql: Union[str, List[str]],
- schema: str = 'default',
- fetch_size: Optional[int] = None,
- hive_conf: Optional[Dict[Any, Any]] = None,
+ sql: str | list[str],
+ schema: str = "default",
+ fetch_size: int | None = None,
+ hive_conf: Iterable | Mapping | None = None,
) -> Any:
from pyhive.exc import ProgrammingError
- if isinstance(hql, str):
- hql = [hql]
+ if isinstance(sql, str):
+ sql = [sql]
previous_description = None
with contextlib.closing(self.get_conn(schema)) as conn, contextlib.closing(conn.cursor()) as cur:
@@ -856,28 +874,28 @@ def _get_results(
db = self.get_connection(self.hiveserver2_conn_id) # type: ignore
- if db.extra_dejson.get('run_set_variable_statements', True):
+ if db.extra_dejson.get("run_set_variable_statements", True):
env_context = get_context_from_env_var()
if hive_conf:
env_context.update(hive_conf)
for k, v in env_context.items():
cur.execute(f"set {k}={v}")
- for statement in hql:
+ for statement in sql:
cur.execute(statement)
# we only get results of statements that returns
lowered_statement = statement.lower().strip()
if (
- lowered_statement.startswith('select')
- or lowered_statement.startswith('with')
- or lowered_statement.startswith('show')
- or (lowered_statement.startswith('set') and '=' not in lowered_statement)
+ lowered_statement.startswith("select")
+ or lowered_statement.startswith("with")
+ or lowered_statement.startswith("show")
+ or (lowered_statement.startswith("set") and "=" not in lowered_statement)
):
description = cur.description
if previous_description and previous_description != description:
- message = f'''The statements are producing different descriptions:
+ message = f"""The statements are producing different descriptions:
Current: {repr(description)}
- Previous: {repr(previous_description)}'''
+ Previous: {repr(previous_description)}"""
raise ValueError(message)
elif not previous_description:
previous_description = description
@@ -892,41 +910,40 @@ def _get_results(
def get_results(
self,
- hql: str,
- schema: str = 'default',
- fetch_size: Optional[int] = None,
- hive_conf: Optional[Dict[Any, Any]] = None,
- ) -> Dict[str, Any]:
+ sql: str | list[str],
+ schema: str = "default",
+ fetch_size: int | None = None,
+ hive_conf: Iterable | Mapping | None = None,
+ ) -> dict[str, Any]:
"""
Get results of the provided hql in target schema.
- :param hql: hql to be executed.
+ :param sql: hql to be executed.
:param schema: target schema, default to 'default'.
:param fetch_size: max size of result to fetch.
:param hive_conf: hive_conf to execute alone with the hql.
:return: results of hql execution, dict with data (list of results) and header
- :rtype: dict
"""
- results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
+ results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
- results = {'data': list(results_iter), 'header': header}
+ results = {"data": list(results_iter), "header": header}
return results
def to_csv(
self,
- hql: str,
+ sql: str,
csv_filepath: str,
- schema: str = 'default',
- delimiter: str = ',',
- lineterminator: str = '\r\n',
+ schema: str = "default",
+ delimiter: str = ",",
+ lineterminator: str = "\r\n",
output_header: bool = True,
fetch_size: int = 1000,
- hive_conf: Optional[Dict[Any, Any]] = None,
+ hive_conf: dict[Any, Any] | None = None,
) -> None:
"""
Execute hql in target schema and write results to a csv file.
- :param hql: hql to be executed.
+ :param sql: hql to be executed.
:param csv_filepath: filepath of csv to write results into.
:param schema: target schema, default to 'default'.
:param delimiter: delimiter of the csv file, default to ','.
@@ -936,16 +953,16 @@ def to_csv(
:param hive_conf: hive_conf to execute alone with the hql.
"""
- results_iter = self._get_results(hql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
+ results_iter = self._get_results(sql, schema, fetch_size=fetch_size, hive_conf=hive_conf)
header = next(results_iter)
message = None
i = 0
- with open(csv_filepath, 'wb') as file:
- writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding='utf-8')
+ with open(csv_filepath, "wb") as file:
+ writer = csv.writer(file, delimiter=delimiter, lineterminator=lineterminator, encoding="utf-8")
try:
if output_header:
- self.log.debug('Cursor description is %s', header)
+ self.log.debug("Cursor description is %s", header)
writer.writerow([c[0] for c in header])
for i, row in enumerate(results_iter, 1):
@@ -963,40 +980,39 @@ def to_csv(
self.log.info("Done. Loaded a total of %s rows.", i)
def get_records(
- self, hql: str, schema: str = 'default', hive_conf: Optional[Dict[Any, Any]] = None
+ self, sql: str | list[str], parameters: Iterable | Mapping | None = None, **kwargs
) -> Any:
"""
- Get a set of records from a Hive query.
+ Get a set of records from a Hive query. You can optionally pass a 'schema' kwarg
+ which specifies the target schema and defaults to 'default'.
- :param hql: hql to be executed.
- :param schema: target schema, default to 'default'.
- :param hive_conf: hive_conf to execute alone with the hql.
+ :param sql: hql to be executed.
+ :param parameters: optional configuration passed to get_results
:return: result of hive execution
- :rtype: list
>>> hh = HiveServer2Hook()
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
>>> len(hh.get_records(sql))
100
"""
- return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data']
+ schema = kwargs.get("schema", "default")
+ return self.get_results(sql, schema=schema, hive_conf=parameters)["data"]
def get_pandas_df( # type: ignore
self,
- hql: str,
- schema: str = 'default',
- hive_conf: Optional[Dict[Any, Any]] = None,
+ sql: str,
+ schema: str = "default",
+ hive_conf: dict[Any, Any] | None = None,
**kwargs,
) -> pandas.DataFrame:
"""
Get a pandas dataframe from a Hive query
- :param hql: hql to be executed.
+ :param sql: hql to be executed.
:param schema: target schema, default to 'default'.
:param hive_conf: hive_conf to execute alone with the hql.
:param kwargs: (optional) passed into pandas.DataFrame constructor
:return: result of hive execution
- :rtype: DataFrame
>>> hh = HiveServer2Hook()
>>> sql = "SELECT * FROM airflow.static_babynames LIMIT 100"
@@ -1006,6 +1022,6 @@ def get_pandas_df( # type: ignore
:return: pandas.DateFrame
"""
- res = self.get_results(hql, schema=schema, hive_conf=hive_conf)
- df = pandas.DataFrame(res['data'], columns=[c[0] for c in res['header']], **kwargs)
+ res = self.get_results(sql, schema=schema, hive_conf=hive_conf)
+ df = pandas.DataFrame(res["data"], columns=[c[0] for c in res["header"]], **kwargs)
return df
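Note on the ``authMechanism`` to ``auth_mechanism`` rename in ``HiveServer2Hook.get_conn``: the key handling can be sketched on a plain dictionary as below. This is a minimal illustration, not the hook's code. The function name ``normalize_auth_extras`` and the ``kerberos_enabled`` flag are hypothetical stand-ins for the connection extras and the ``conf.get("core", "security") == "kerberos"`` check, and the sketch assumes the ``GSSAPI`` translation applies regardless of the Kerberos setting.

```python
import warnings


def normalize_auth_extras(extras: dict, kerberos_enabled: bool = False) -> dict:
    """Return a copy of ``extras`` with the legacy key renamed and the mechanism resolved.

    Hypothetical helper: it mirrors the key handling shown in the diff but
    works on a plain dict instead of an Airflow connection object.
    """
    normalized = dict(extras)  # work on a copy; the caller's dict is left untouched
    if "authMechanism" in normalized:
        warnings.warn(
            "The 'authMechanism' option is deprecated. Please use 'auth_mechanism'.",
            DeprecationWarning,
            stacklevel=2,
        )
        # The legacy value wins if both keys are present, as in the diff.
        normalized["auth_mechanism"] = normalized.pop("authMechanism")

    # The default depends on whether Kerberos security is enabled.
    default = "KERBEROS" if kerberos_enabled else "NONE"
    mechanism = normalized.get("auth_mechanism", default)

    # 'GSSAPI' is the deprecated spelling; it is translated to 'KERBEROS'.
    if mechanism == "GSSAPI":
        mechanism = "KERBEROS"
    normalized["auth_mechanism"] = mechanism
    return normalized
```

With this sketch, a connection extra of ``{"authMechanism": "GSSAPI"}`` emits a ``DeprecationWarning`` and resolves to ``{"auth_mechanism": "KERBEROS"}``.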
diff --git a/airflow/providers/apache/hive/operators/hive.py b/airflow/providers/apache/hive/operators/hive.py
index 45cae0fa4e31f..23f6c32edd3d7 100644
--- a/airflow/providers/apache/hive/operators/hive.py
+++ b/airflow/providers/apache/hive/operators/hive.py
@@ -15,9 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import os
import re
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.configuration import conf
from airflow.models import BaseOperator
@@ -57,34 +59,34 @@ class HiveOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'hql',
- 'schema',
- 'hive_cli_conn_id',
- 'mapred_queue',
- 'hiveconfs',
- 'mapred_job_name',
- 'mapred_queue_priority',
+ "hql",
+ "schema",
+ "hive_cli_conn_id",
+ "mapred_queue",
+ "hiveconfs",
+ "mapred_job_name",
+ "mapred_queue_priority",
)
template_ext: Sequence[str] = (
- '.hql',
- '.sql',
+ ".hql",
+ ".sql",
)
- template_fields_renderers = {'hql': 'hql'}
- ui_color = '#f0e4ec'
+ template_fields_renderers = {"hql": "hql"}
+ ui_color = "#f0e4ec"
def __init__(
self,
*,
hql: str,
- hive_cli_conn_id: str = 'hive_cli_default',
- schema: str = 'default',
- hiveconfs: Optional[Dict[Any, Any]] = None,
+ hive_cli_conn_id: str = "hive_cli_default",
+ schema: str = "default",
+ hiveconfs: dict[Any, Any] | None = None,
hiveconf_jinja_translate: bool = False,
- script_begin_tag: Optional[str] = None,
+ script_begin_tag: str | None = None,
run_as_owner: bool = False,
- mapred_queue: Optional[str] = None,
- mapred_queue_priority: Optional[str] = None,
- mapred_job_name: Optional[str] = None,
+ mapred_queue: str | None = None,
+ mapred_queue_priority: str | None = None,
+ mapred_job_name: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -101,20 +103,18 @@ def __init__(
self.mapred_queue_priority = mapred_queue_priority
self.mapred_job_name = mapred_job_name
- job_name_template = conf.get(
- 'hive',
- 'mapred_job_name_template',
+ job_name_template = conf.get_mandatory_value(
+ "hive",
+ "mapred_job_name_template",
fallback="Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}",
)
- if job_name_template is None:
- raise ValueError("Job name template should be set !")
self.mapred_job_name_template: str = job_name_template
# assigned lazily - just for consistency we can create the attribute with a
# `None` initial value, later it will be populated by the execute method.
# This also makes `on_kill` implementation consistent since it assumes `self.hook`
# is defined.
- self.hook: Optional[HiveCliHook] = None
+ self.hook: HiveCliHook | None = None
def get_hook(self) -> HiveCliHook:
"""Get Hive cli hook"""
@@ -132,18 +132,18 @@ def prepare_template(self) -> None:
if self.script_begin_tag and self.script_begin_tag in self.hql:
self.hql = "\n".join(self.hql.split(self.script_begin_tag)[1:])
- def execute(self, context: "Context") -> None:
- self.log.info('Executing: %s', self.hql)
+ def execute(self, context: Context) -> None:
+ self.log.info("Executing: %s", self.hql)
self.hook = self.get_hook()
# set the mapred_job_name if it's not set with dag, task, execution time info
if not self.mapred_job_name:
- ti = context['ti']
+ ti = context["ti"]
self.hook.mapred_job_name = self.mapred_job_name_template.format(
dag_id=ti.dag_id,
task_id=ti.task_id,
execution_date=ti.execution_date.isoformat(),
- hostname=ti.hostname.split('.')[0],
+ hostname=ti.hostname.split(".")[0],
)
if self.hiveconf_jinja_translate:
@@ -151,7 +151,7 @@ def execute(self, context: "Context") -> None:
else:
self.hiveconfs.update(context_to_airflow_vars(context))
- self.log.info('Passing HiveConf: %s', self.hiveconfs)
+ self.log.info("Passing HiveConf: %s", self.hiveconfs)
self.hook.run_cli(hql=self.hql, schema=self.schema, hive_conf=self.hiveconfs)
def dry_run(self) -> None:
@@ -169,6 +169,6 @@ def on_kill(self) -> None:
def clear_airflow_vars(self) -> None:
"""Reset airflow environment variables to prevent existing ones from impacting behavior."""
blank_env_vars = {
- value['env_var_format']: '' for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
+ value["env_var_format"]: "" for value in operator_helpers.AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
}
os.environ.update(blank_env_vars)
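Note on ``mapred_job_name_template`` in ``HiveOperator``: the template is an ordinary ``str.format`` string, and ``execute`` fills it from the task instance. The sketch below reproduces that formatting step in isolation. ``render_job_name`` and the sample values are hypothetical; only the template text and the four placeholder names come from the diff.

```python
from datetime import datetime, timezone

# Fallback template taken from the diff above.
DEFAULT_TEMPLATE = "Airflow HiveOperator task for {hostname}.{dag_id}.{task_id}.{execution_date}"


def render_job_name(template: str, *, dag_id: str, task_id: str,
                    execution_date: datetime, hostname: str) -> str:
    """Fill the job-name template the way ``execute`` does (illustrative helper)."""
    return template.format(
        dag_id=dag_id,
        task_id=task_id,
        execution_date=execution_date.isoformat(),
        # Only the short host name (text before the first dot) is used.
        hostname=hostname.split(".")[0],
    )


name = render_job_name(
    DEFAULT_TEMPLATE,
    dag_id="etl",
    task_id="load",
    execution_date=datetime(2022, 1, 1, tzinfo=timezone.utc),
    hostname="worker-1.example.internal",
)
print(name)
```

A custom template may use any subset of the four placeholders. ``str.format`` raises ``KeyError`` on a placeholder outside that set.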
diff --git a/airflow/providers/apache/hive/operators/hive_stats.py b/airflow/providers/apache/hive/operators/hive_stats.py
index a1b5539622321..caa6770a8fdf8 100644
--- a/airflow/providers/apache/hive/operators/hive_stats.py
+++ b/airflow/providers/apache/hive/operators/hive_stats.py
@@ -15,10 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import json
import warnings
from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Callable, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -58,23 +60,23 @@ class HiveStatsCollectionOperator(BaseOperator):
column.
"""
- template_fields: Sequence[str] = ('table', 'partition', 'ds', 'dttm')
- ui_color = '#aff7a6'
+ template_fields: Sequence[str] = ("table", "partition", "ds", "dttm")
+ ui_color = "#aff7a6"
def __init__(
self,
*,
table: str,
partition: Any,
- extra_exprs: Optional[Dict[str, Any]] = None,
- excluded_columns: Optional[List[str]] = None,
- assignment_func: Optional[Callable[[str, str], Optional[Dict[Any, Any]]]] = None,
- metastore_conn_id: str = 'metastore_default',
- presto_conn_id: str = 'presto_default',
- mysql_conn_id: str = 'airflow_db',
+ extra_exprs: dict[str, Any] | None = None,
+ excluded_columns: list[str] | None = None,
+ assignment_func: Callable[[str, str], dict[Any, Any] | None] | None = None,
+ metastore_conn_id: str = "metastore_default",
+ presto_conn_id: str = "presto_default",
+ mysql_conn_id: str = "airflow_db",
**kwargs: Any,
) -> None:
- if 'col_blacklist' in kwargs:
+ if "col_blacklist" in kwargs:
warnings.warn(
f"col_blacklist kwarg passed to {self.__class__.__name__} "
f"(task_id: {kwargs.get('task_id')}) is deprecated, "
@@ -82,44 +84,44 @@ def __init__(
category=FutureWarning,
stacklevel=2,
)
- excluded_columns = kwargs.pop('col_blacklist')
+ excluded_columns = kwargs.pop("col_blacklist")
super().__init__(**kwargs)
self.table = table
self.partition = partition
self.extra_exprs = extra_exprs or {}
- self.excluded_columns = excluded_columns or [] # type: List[str]
+ self.excluded_columns: list[str] = excluded_columns or []
self.metastore_conn_id = metastore_conn_id
self.presto_conn_id = presto_conn_id
self.mysql_conn_id = mysql_conn_id
self.assignment_func = assignment_func
- self.ds = '{{ ds }}'
- self.dttm = '{{ execution_date.isoformat() }}'
+ self.ds = "{{ ds }}"
+ self.dttm = "{{ execution_date.isoformat() }}"
- def get_default_exprs(self, col: str, col_type: str) -> Dict[Any, Any]:
+ def get_default_exprs(self, col: str, col_type: str) -> dict[Any, Any]:
"""Get default expressions"""
if col in self.excluded_columns:
return {}
- exp = {(col, 'non_null'): f"COUNT({col})"}
- if col_type in ['double', 'int', 'bigint', 'float']:
- exp[(col, 'sum')] = f'SUM({col})'
- exp[(col, 'min')] = f'MIN({col})'
- exp[(col, 'max')] = f'MAX({col})'
- exp[(col, 'avg')] = f'AVG({col})'
- elif col_type == 'boolean':
- exp[(col, 'true')] = f'SUM(CASE WHEN {col} THEN 1 ELSE 0 END)'
- exp[(col, 'false')] = f'SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)'
- elif col_type in ['string']:
- exp[(col, 'len')] = f'SUM(CAST(LENGTH({col}) AS BIGINT))'
- exp[(col, 'approx_distinct')] = f'APPROX_DISTINCT({col})'
+ exp = {(col, "non_null"): f"COUNT({col})"}
+ if col_type in {"double", "int", "bigint", "float"}:
+ exp[(col, "sum")] = f"SUM({col})"
+ exp[(col, "min")] = f"MIN({col})"
+ exp[(col, "max")] = f"MAX({col})"
+ exp[(col, "avg")] = f"AVG({col})"
+ elif col_type == "boolean":
+ exp[(col, "true")] = f"SUM(CASE WHEN {col} THEN 1 ELSE 0 END)"
+ exp[(col, "false")] = f"SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)"
+ elif col_type == "string":
+ exp[(col, "len")] = f"SUM(CAST(LENGTH({col}) AS BIGINT))"
+ exp[(col, "approx_distinct")] = f"APPROX_DISTINCT({col})"
return exp
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
metastore = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
table = metastore.get_table(table_name=self.table)
field_types = {col.name: col.type for col in table.sd.cols}
- exprs: Any = {('', 'count'): 'COUNT(*)'}
+ exprs: Any = {("", "count"): "COUNT(*)"}
for col, col_type in list(field_types.items()):
if self.assignment_func:
assign_exprs = self.assignment_func(col, col_type)
@@ -130,15 +132,15 @@ def execute(self, context: "Context") -> None:
exprs.update(assign_exprs)
exprs.update(self.extra_exprs)
exprs = OrderedDict(exprs)
- exprs_str = ",\n ".join(v + " AS " + k[0] + '__' + k[1] for k, v in exprs.items())
+ exprs_str = ",\n ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())
where_clause_ = [f"{k} = '{v}'" for k, v in self.partition.items()]
where_clause = " AND\n ".join(where_clause_)
sql = f"SELECT {exprs_str} FROM {self.table} WHERE {where_clause};"
presto = PrestoHook(presto_conn_id=self.presto_conn_id)
- self.log.info('Executing SQL check: %s', sql)
- row = presto.get_first(hql=sql)
+ self.log.info("Executing SQL check: %s", sql)
+ row = presto.get_first(sql)
self.log.info("Record: %s", row)
if not row:
raise AirflowException("The query returned None")
@@ -170,15 +172,15 @@ def execute(self, context: "Context") -> None:
(self.ds, self.dttm, self.table, part_json) + (r[0][0], r[0][1], r[1]) for r in zip(exprs, row)
]
mysql.insert_rows(
- table='hive_stats',
+ table="hive_stats",
rows=rows,
target_fields=[
- 'ds',
- 'dttm',
- 'table_name',
- 'partition_repr',
- 'col',
- 'metric',
- 'value',
+ "ds",
+ "dttm",
+ "table_name",
+ "partition_repr",
+ "col",
+ "metric",
+ "value",
],
)
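Note on ``HiveStatsCollectionOperator``: the operator maps each column to a set of aggregate expressions keyed by ``(column, metric)`` and joins them into a select list with ``column__metric`` aliases. The following sketch reproduces that construction without any Airflow or metastore dependency. ``default_exprs`` and ``build_select_list`` are illustrative names, and the ``field_types`` mapping stands in for the column types read from the metastore.

```python
def default_exprs(col: str, col_type: str) -> dict:
    """Return {(column, metric): SQL expression} for one column, as in the diff."""
    exp = {(col, "non_null"): f"COUNT({col})"}
    if col_type in {"double", "int", "bigint", "float"}:
        exp[(col, "sum")] = f"SUM({col})"
        exp[(col, "min")] = f"MIN({col})"
        exp[(col, "max")] = f"MAX({col})"
        exp[(col, "avg")] = f"AVG({col})"
    elif col_type == "boolean":
        exp[(col, "true")] = f"SUM(CASE WHEN {col} THEN 1 ELSE 0 END)"
        exp[(col, "false")] = f"SUM(CASE WHEN NOT {col} THEN 1 ELSE 0 END)"
    elif col_type == "string":
        exp[(col, "len")] = f"SUM(CAST(LENGTH({col}) AS BIGINT))"
        exp[(col, "approx_distinct")] = f"APPROX_DISTINCT({col})"
    return exp


def build_select_list(field_types: dict) -> str:
    """Join all expressions into a select list with ``col__metric`` aliases."""
    exprs = {("", "count"): "COUNT(*)"}  # table-level row count comes first
    for col, col_type in field_types.items():
        exprs.update(default_exprs(col, col_type))
    return ",\n    ".join(f"{v} AS {k[0]}__{k[1]}" for k, v in exprs.items())


print(build_select_list({"flag": "boolean", "price": "double"}))
```

The alias of the row count is ``__count`` because its column part is the empty string.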
diff --git a/airflow/providers/apache/hive/provider.yaml b/airflow/providers/apache/hive/provider.yaml
index cf00cf2e7cc88..d3da54dfbb57c 100644
--- a/airflow/providers/apache/hive/provider.yaml
+++ b/airflow/providers/apache/hive/provider.yaml
@@ -22,6 +22,11 @@ description: |
`Apache Hive `__
versions:
+ - 4.1.0
+ - 4.0.1
+ - 4.0.0
+ - 3.1.0
+ - 3.0.0
- 2.3.3
- 2.3.2
- 2.3.1
@@ -37,8 +42,17 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-common-sql>=1.3.1
+ - hmsclient>=0.1.0
+ - pandas>=0.17.1
+ - pyhive[hive]>=0.6.0
+ # in case of Python 3.9 sasl library needs to be installed with version higher or equal than
+ # 0.3.1 because only that version supports Python 3.9. For other Python version pyhive[hive] pulls
+ # the sasl library anyway (and there sasl library version is not relevant)
+ - sasl>=0.3.1; python_version>="3.9"
+ - thrift>=0.9.2
integrations:
- integration-name: Apache Hive
@@ -86,12 +100,6 @@ transfers:
target-integration-name: Apache Hive
python-module: airflow.providers.apache.hive.transfers.mssql_to_hive
-hook-class-names:
- # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.hive.hooks.hive.HiveCliHook
- - airflow.providers.apache.hive.hooks.hive.HiveServer2Hook
- - airflow.providers.apache.hive.hooks.hive.HiveMetastoreHook
-
connection-types:
- hook-class-name: airflow.providers.apache.hive.hooks.hive.HiveCliHook
connection-type: hive_cli
diff --git a/airflow/providers/apache/hive/sensors/hive_partition.py b/airflow/providers/apache/hive/sensors/hive_partition.py
index f03dcb18f52e1..d839bb444fcbd 100644
--- a/airflow/providers/apache/hive/sensors/hive_partition.py
+++ b/airflow/providers/apache/hive/sensors/hive_partition.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.providers.apache.hive.hooks.hive import HiveMetastoreHook
from airflow.sensors.base import BaseSensorOperator
@@ -43,19 +45,19 @@ class HivePartitionSensor(BaseSensorOperator):
"""
template_fields: Sequence[str] = (
- 'schema',
- 'table',
- 'partition',
+ "schema",
+ "table",
+ "partition",
)
- ui_color = '#C5CAE9'
+ ui_color = "#C5CAE9"
def __init__(
self,
*,
table: str,
- partition: Optional[str] = "ds='{{ ds }}'",
- metastore_conn_id: str = 'metastore_default',
- schema: str = 'default',
+ partition: str | None = "ds='{{ ds }}'",
+ metastore_conn_id: str = "metastore_default",
+ schema: str = "default",
poke_interval: int = 60 * 3,
**kwargs: Any,
):
@@ -67,10 +69,10 @@ def __init__(
self.partition = partition
self.schema = schema
- def poke(self, context: "Context") -> bool:
- if '.' in self.table:
- self.schema, self.table = self.table.split('.')
- self.log.info('Poking for table %s.%s, partition %s', self.schema, self.table, self.partition)
- if not hasattr(self, 'hook'):
+ def poke(self, context: Context) -> bool:
+ if "." in self.table:
+ self.schema, self.table = self.table.split(".")
+ self.log.info("Poking for table %s.%s, partition %s", self.schema, self.table, self.partition)
+ if not hasattr(self, "hook"):
hook = HiveMetastoreHook(metastore_conn_id=self.metastore_conn_id)
return hook.check_for_partition(self.schema, self.table, self.partition)
diff --git a/airflow/providers/apache/hive/sensors/metastore_partition.py b/airflow/providers/apache/hive/sensors/metastore_partition.py
index ea6c1525a1d57..57e793849efce 100644
--- a/airflow/providers/apache/hive/sensors/metastore_partition.py
+++ b/airflow/providers/apache/hive/sensors/metastore_partition.py
@@ -15,9 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Any, Sequence
-from airflow.sensors.sql import SqlSensor
+from airflow.providers.common.sql.sensors.sql import SqlSensor
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -40,9 +42,8 @@ class MetastorePartitionSensor(SqlSensor):
:param mysql_conn_id: a reference to the MySQL conn_id for the metastore
"""
- template_fields: Sequence[str] = ('partition_name', 'table', 'schema')
- ui_color = '#8da7be'
- poke_context_fields = ('partition_name', 'table', 'schema', 'mysql_conn_id')
+ template_fields: Sequence[str] = ("partition_name", "table", "schema")
+ ui_color = "#8da7be"
def __init__(
self,
@@ -66,11 +67,11 @@ def __init__(
# constructor below and apply_defaults will no longer throw an exception.
super().__init__(**kwargs)
- def poke(self, context: "Context") -> Any:
+ def poke(self, context: Context) -> Any:
if self.first_poke:
self.first_poke = False
- if '.' in self.table:
- self.schema, self.table = self.table.split('.')
+ if "." in self.table:
+ self.schema, self.table = self.table.split(".")
self.sql = """
SELECT 'X'
FROM PARTITIONS A0
diff --git a/airflow/providers/apache/hive/sensors/named_hive_partition.py b/airflow/providers/apache/hive/sensors/named_hive_partition.py
index 9535bcdab0219..1a7e5ee36ebf8 100644
--- a/airflow/providers/apache/hive/sensors/named_hive_partition.py
+++ b/airflow/providers/apache/hive/sensors/named_hive_partition.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, List, Sequence, Tuple
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.sensors.base import BaseSensorOperator
@@ -38,15 +40,14 @@ class NamedHivePartitionSensor(BaseSensorOperator):
:ref:`metastore thrift service connection id `.
"""
- template_fields: Sequence[str] = ('partition_names',)
- ui_color = '#8d99ae'
- poke_context_fields = ('partition_names', 'metastore_conn_id')
+ template_fields: Sequence[str] = ("partition_names",)
+ ui_color = "#8d99ae"
def __init__(
self,
*,
- partition_names: List[str],
- metastore_conn_id: str = 'metastore_default',
+ partition_names: list[str],
+ metastore_conn_id: str = "metastore_default",
poke_interval: int = 60 * 3,
hook: Any = None,
**kwargs: Any,
@@ -55,28 +56,28 @@ def __init__(
self.next_index_to_poke = 0
if isinstance(partition_names, str):
- raise TypeError('partition_names must be an array of strings')
+ raise TypeError("partition_names must be an array of strings")
self.metastore_conn_id = metastore_conn_id
self.partition_names = partition_names
self.hook = hook
- if self.hook and metastore_conn_id != 'metastore_default':
+ if self.hook and metastore_conn_id != "metastore_default":
self.log.warning(
- 'A hook was passed but a non default metastore_conn_id=%s was used', metastore_conn_id
+ "A hook was passed but a non default metastore_conn_id=%s was used", metastore_conn_id
)
@staticmethod
- def parse_partition_name(partition: str) -> Tuple[Any, ...]:
+ def parse_partition_name(partition: str) -> tuple[Any, ...]:
"""Get schema, table, and partition info."""
- first_split = partition.split('.', 1)
+ first_split = partition.split(".", 1)
if len(first_split) == 1:
- schema = 'default'
+ schema = "default"
table_partition = max(first_split) # poor man first
else:
schema, table_partition = first_split
- second_split = table_partition.split('/', 1)
+ second_split = table_partition.split("/", 1)
if len(second_split) == 1:
- raise ValueError('Could not parse ' + partition + 'into table, partition')
+ raise ValueError(f"Could not parse {partition} into table, partition")
else:
table, partition = second_split
return schema, table, partition
@@ -90,10 +91,10 @@ def poke_partition(self, partition: str) -> Any:
schema, table, partition = self.parse_partition_name(partition)
- self.log.info('Poking for %s.%s/%s', schema, table, partition)
+ self.log.info("Poking for %s.%s/%s", schema, table, partition)
return self.hook.check_for_named_partition(schema, table, partition)
- def poke(self, context: "Context") -> bool:
+ def poke(self, context: Context) -> bool:
number_of_partitions = len(self.partition_names)
poke_index_start = self.next_index_to_poke
@@ -104,12 +105,3 @@ def poke(self, context: "Context") -> bool:
self.next_index_to_poke = 0
return True
-
- def is_smart_sensor_compatible(self):
- result = (
- not self.soft_fail
- and not self.hook
- and len(self.partition_names) <= 30
- and super().is_smart_sensor_compatible()
- )
- return result
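Note on ``NamedHivePartitionSensor.parse_partition_name``: the parsing rule is "split once on the first dot for the schema, then once on the first slash for the table". A standalone sketch of that rule follows. It is a reimplementation for illustration rather than the sensor's code, and the error message includes a space before "into".

```python
def parse_partition_name(partition: str) -> tuple:
    """Split ``schema.table/partition_spec`` into its three parts.

    Illustrative reimplementation: the schema defaults to ``default``
    when the name contains no dot.
    """
    first_split = partition.split(".", 1)
    if len(first_split) == 1:
        schema, table_partition = "default", first_split[0]
    else:
        schema, table_partition = first_split

    second_split = table_partition.split("/", 1)
    if len(second_split) == 1:
        raise ValueError(f"Could not parse {partition} into table, partition")
    table, partition_spec = second_split
    return schema, table, partition_spec


print(parse_partition_name("airflow.static_babynames/ds=2015-01-01"))
print(parse_partition_name("static_babynames/ds=2015-01-01"))
```

Because the schema split happens first, a name without a schema whose partition value contains a dot is split at that dot. Qualifying such names with an explicit schema avoids the ambiguity.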
diff --git a/airflow/providers/apache/hive/transfers/hive_to_mysql.py b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
index 65de4cc159ae2..b1a3669d7171c 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_mysql.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_mysql.py
@@ -15,23 +15,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains an operator to move data from Hive to MySQL."""
+from __future__ import annotations
+
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveServer2Hook
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.utils.operator_helpers import context_to_airflow_vars
-from airflow.www import utils as wwwutils
if TYPE_CHECKING:
from airflow.utils.context import Context
-# TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
-MYSQL_RENDERER = 'mysql' if 'mysql' in wwwutils.get_attr_renderer() else 'sql'
-
class HiveToMySqlOperator(BaseOperator):
"""
@@ -59,26 +56,26 @@ class HiveToMySqlOperator(BaseOperator):
:param hive_conf:
"""
- template_fields: Sequence[str] = ('sql', 'mysql_table', 'mysql_preoperator', 'mysql_postoperator')
- template_ext: Sequence[str] = ('.sql',)
+ template_fields: Sequence[str] = ("sql", "mysql_table", "mysql_preoperator", "mysql_postoperator")
+ template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {
- 'sql': 'hql',
- 'mysql_preoperator': MYSQL_RENDERER,
- 'mysql_postoperator': MYSQL_RENDERER,
+ "sql": "hql",
+ "mysql_preoperator": "mysql",
+ "mysql_postoperator": "mysql",
}
- ui_color = '#a0e08c'
+ ui_color = "#a0e08c"
def __init__(
self,
*,
sql: str,
mysql_table: str,
- hiveserver2_conn_id: str = 'hiveserver2_default',
- mysql_conn_id: str = 'mysql_default',
- mysql_preoperator: Optional[str] = None,
- mysql_postoperator: Optional[str] = None,
+ hiveserver2_conn_id: str = "hiveserver2_default",
+ mysql_conn_id: str = "mysql_default",
+ mysql_preoperator: str | None = None,
+ mysql_postoperator: str | None = None,
bulk_load: bool = False,
- hive_conf: Optional[Dict] = None,
+ hive_conf: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -91,7 +88,7 @@ def __init__(
self.bulk_load = bulk_load
self.hive_conf = hive_conf
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
self.log.info("Extracting data from Hive: %s", self.sql)
@@ -103,15 +100,15 @@ def execute(self, context: 'Context'):
hive.to_csv(
self.sql,
tmp_file.name,
- delimiter='\t',
- lineterminator='\n',
+ delimiter="\t",
+ lineterminator="\n",
output_header=False,
hive_conf=hive_conf,
)
mysql = self._call_preoperator()
mysql.bulk_load(table=self.mysql_table, tmp_file=tmp_file.name)
else:
- hive_results = hive.get_records(self.sql, hive_conf=hive_conf)
+ hive_results = hive.get_records(self.sql, parameters=hive_conf)
mysql = self._call_preoperator()
mysql.insert_rows(table=self.mysql_table, rows=hive_results)
diff --git a/airflow/providers/apache/hive/transfers/hive_to_samba.py b/airflow/providers/apache/hive/transfers/hive_to_samba.py
index c5ab66efa42e0..0ad815ea9cf73 100644
--- a/airflow/providers/apache/hive/transfers/hive_to_samba.py
+++ b/airflow/providers/apache/hive/transfers/hive_to_samba.py
@@ -15,8 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains an operator to move data from Hive to Samba."""
+from __future__ import annotations
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Sequence
@@ -42,33 +42,33 @@ class HiveToSambaOperator(BaseOperator):
:ref: `Hive Server2 thrift service connection id `.
"""
- template_fields: Sequence[str] = ('hql', 'destination_filepath')
+ template_fields: Sequence[str] = ("hql", "destination_filepath")
template_ext: Sequence[str] = (
- '.hql',
- '.sql',
+ ".hql",
+ ".sql",
)
- template_fields_renderers = {'hql': 'hql'}
+ template_fields_renderers = {"hql": "hql"}
def __init__(
self,
*,
hql: str,
destination_filepath: str,
- samba_conn_id: str = 'samba_default',
- hiveserver2_conn_id: str = 'hiveserver2_default',
+ samba_conn_id: str = "samba_default",
+ hiveserver2_conn_id: str = "hiveserver2_default",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.hiveserver2_conn_id = hiveserver2_conn_id
self.samba_conn_id = samba_conn_id
self.destination_filepath = destination_filepath
- self.hql = hql.strip().rstrip(';')
+ self.hql = hql.strip().rstrip(";")
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
with NamedTemporaryFile() as tmp_file:
self.log.info("Fetching file from Hive")
hive = HiveServer2Hook(hiveserver2_conn_id=self.hiveserver2_conn_id)
- hive.to_csv(hql=self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
+ hive.to_csv(self.hql, csv_filepath=tmp_file.name, hive_conf=context_to_airflow_vars(context))
self.log.info("Pushing to samba")
samba = SambaHook(samba_conn_id=self.samba_conn_id)
samba.push_from_local(self.destination_filepath, tmp_file.name)
diff --git a/airflow/providers/apache/hive/transfers/mssql_to_hive.py b/airflow/providers/apache/hive/transfers/mssql_to_hive.py
index 912c2a58a36bd..9cdd581911238 100644
--- a/airflow/providers/apache/hive/transfers/mssql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -15,12 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains an operator to move data from MSSQL to Hive."""
+from __future__ import annotations
from collections import OrderedDict
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
import pymssql
import unicodecsv as csv
@@ -28,7 +28,6 @@
from airflow.models import BaseOperator
from airflow.providers.apache.hive.hooks.hive import HiveCliHook
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
-from airflow.www import utils as wwwutils
if TYPE_CHECKING:
from airflow.utils.context import Context
@@ -64,11 +63,10 @@ class MsSqlToHiveOperator(BaseOperator):
:param tblproperties: TBLPROPERTIES of the hive table being created
"""
- template_fields: Sequence[str] = ('sql', 'partition', 'hive_table')
- template_ext: Sequence[str] = ('.sql',)
- # TODO: Remove renderer check when the provider has an Airflow 2.3+ requirement.
- template_fields_renderers = {'sql': 'tsql' if 'tsql' in wwwutils.get_attr_renderer() else 'sql'}
- ui_color = '#a0e08c'
+ template_fields: Sequence[str] = ("sql", "partition", "hive_table")
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "tsql"}
+ ui_color = "#a0e08c"
def __init__(
self,
@@ -77,11 +75,11 @@ def __init__(
hive_table: str,
create: bool = True,
recreate: bool = False,
- partition: Optional[Dict] = None,
+ partition: dict | None = None,
delimiter: str = chr(1),
- mssql_conn_id: str = 'mssql_default',
- hive_cli_conn_id: str = 'hive_cli_default',
- tblproperties: Optional[Dict] = None,
+ mssql_conn_id: str = "mssql_default",
+ hive_cli_conn_id: str = "hive_cli_default",
+ tblproperties: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -100,26 +98,24 @@ def __init__(
def type_map(cls, mssql_type: int) -> str:
"""Maps MsSQL type to Hive type."""
map_dict = {
- pymssql.BINARY.value: 'INT',
- pymssql.DECIMAL.value: 'FLOAT',
- pymssql.NUMBER.value: 'INT',
+ pymssql.BINARY.value: "INT",
+ pymssql.DECIMAL.value: "FLOAT",
+ pymssql.NUMBER.value: "INT",
}
- return map_dict.get(mssql_type, 'STRING')
+ return map_dict.get(mssql_type, "STRING")
- def execute(self, context: "Context"):
+ def execute(self, context: Context):
mssql = MsSqlHook(mssql_conn_id=self.mssql_conn_id)
self.log.info("Dumping Microsoft SQL Server query results to local file")
with mssql.get_conn() as conn:
with conn.cursor() as cursor:
cursor.execute(self.sql)
with NamedTemporaryFile("w") as tmp_file:
- csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
+ csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding="utf-8")
field_dict = OrderedDict()
- col_count = 0
- for field in cursor.description:
- col_count += 1
+ for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
- field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
+ field_dict[col_position if field[0] == "" else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor)
tmp_file.flush()
diff --git a/airflow/providers/apache/hive/transfers/mysql_to_hive.py b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
index ccdaf7b682fc1..1693bf48dad9a 100644
--- a/airflow/providers/apache/hive/transfers/mysql_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -15,12 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains an operator to move data from MySQL to Hive."""
+from __future__ import annotations
from collections import OrderedDict
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
import MySQLdb
import unicodecsv as csv
@@ -68,10 +68,10 @@ class MySqlToHiveOperator(BaseOperator):
:param tblproperties: TBLPROPERTIES of the hive table being created
"""
- template_fields: Sequence[str] = ('sql', 'partition', 'hive_table')
- template_ext: Sequence[str] = ('.sql',)
- template_fields_renderers = {'sql': 'mysql'}
- ui_color = '#a0e08c'
+ template_fields: Sequence[str] = ("sql", "partition", "hive_table")
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "mysql"}
+ ui_color = "#a0e08c"
def __init__(
self,
@@ -80,14 +80,14 @@ def __init__(
hive_table: str,
create: bool = True,
recreate: bool = False,
- partition: Optional[Dict] = None,
+ partition: dict | None = None,
delimiter: str = chr(1),
- quoting: Optional[str] = None,
+ quoting: str | None = None,
quotechar: str = '"',
- escapechar: Optional[str] = None,
- mysql_conn_id: str = 'mysql_default',
- hive_cli_conn_id: str = 'hive_cli_default',
- tblproperties: Optional[Dict] = None,
+ escapechar: str | None = None,
+ mysql_conn_id: str = "mysql_default",
+ hive_cli_conn_id: str = "hive_cli_default",
+ tblproperties: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -110,22 +110,22 @@ def type_map(cls, mysql_type: int) -> str:
"""Maps MySQL type to Hive type."""
types = MySQLdb.constants.FIELD_TYPE
type_map = {
- types.BIT: 'INT',
- types.DECIMAL: 'DOUBLE',
- types.NEWDECIMAL: 'DOUBLE',
- types.DOUBLE: 'DOUBLE',
- types.FLOAT: 'DOUBLE',
- types.INT24: 'INT',
- types.LONG: 'BIGINT',
- types.LONGLONG: 'DECIMAL(38,0)',
- types.SHORT: 'INT',
- types.TINY: 'SMALLINT',
- types.YEAR: 'INT',
- types.TIMESTAMP: 'TIMESTAMP',
+ types.BIT: "INT",
+ types.DECIMAL: "DOUBLE",
+ types.NEWDECIMAL: "DOUBLE",
+ types.DOUBLE: "DOUBLE",
+ types.FLOAT: "DOUBLE",
+ types.INT24: "INT",
+ types.LONG: "BIGINT",
+ types.LONGLONG: "DECIMAL(38,0)",
+ types.SHORT: "INT",
+ types.TINY: "SMALLINT",
+ types.YEAR: "INT",
+ types.TIMESTAMP: "TIMESTAMP",
}
- return type_map.get(mysql_type, 'STRING')
+ return type_map.get(mysql_type, "STRING")
- def execute(self, context: "Context"):
+ def execute(self, context: Context):
hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
diff --git a/airflow/providers/apache/hive/transfers/s3_to_hive.py b/airflow/providers/apache/hive/transfers/s3_to_hive.py
index cc189303e0584..470d1d23c7306 100644
--- a/airflow/providers/apache/hive/transfers/s3_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/s3_to_hive.py
@@ -15,15 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains an operator to move data from an S3 bucket to Hive."""
+from __future__ import annotations
import bz2
import gzip
import os
import tempfile
from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -86,29 +86,29 @@ class S3ToHiveOperator(BaseOperator):
:param select_expression: S3 Select expression
"""
- template_fields: Sequence[str] = ('s3_key', 'partition', 'hive_table')
+ template_fields: Sequence[str] = ("s3_key", "partition", "hive_table")
template_ext: Sequence[str] = ()
- ui_color = '#a0e08c'
+ ui_color = "#a0e08c"
def __init__(
self,
*,
s3_key: str,
- field_dict: Dict,
+ field_dict: dict,
hive_table: str,
- delimiter: str = ',',
+ delimiter: str = ",",
create: bool = True,
recreate: bool = False,
- partition: Optional[Dict] = None,
+ partition: dict | None = None,
headers: bool = False,
check_headers: bool = False,
wildcard_match: bool = False,
- aws_conn_id: str = 'aws_default',
- verify: Optional[Union[bool, str]] = None,
- hive_cli_conn_id: str = 'hive_cli_default',
+ aws_conn_id: str = "aws_default",
+ verify: bool | str | None = None,
+ hive_cli_conn_id: str = "hive_cli_default",
input_compressed: bool = False,
- tblproperties: Optional[Dict] = None,
- select_expression: Optional[str] = None,
+ tblproperties: dict | None = None,
+ select_expression: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -132,7 +132,7 @@ def __init__(
if self.check_headers and not (self.field_dict is not None and self.headers):
raise AirflowException("To check_headers provide field_dict and headers")
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
# Downloading file from S3
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
hive_hook = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
@@ -142,29 +142,29 @@ def execute(self, context: 'Context'):
if not s3_hook.check_for_wildcard_key(self.s3_key):
raise AirflowException(f"No key matches {self.s3_key}")
s3_key_object = s3_hook.get_wildcard_key(self.s3_key)
- else:
- if not s3_hook.check_for_key(self.s3_key):
- raise AirflowException(f"The key {self.s3_key} does not exists")
+ elif s3_hook.check_for_key(self.s3_key):
s3_key_object = s3_hook.get_key(self.s3_key)
+ else:
+ raise AirflowException(f"The key {self.s3_key} does not exists")
_, file_ext = os.path.splitext(s3_key_object.key)
- if self.select_expression and self.input_compressed and file_ext.lower() != '.gz':
+ if self.select_expression and self.input_compressed and file_ext.lower() != ".gz":
raise AirflowException("GZIP is the only compression format Amazon S3 Select supports")
- with TemporaryDirectory(prefix='tmps32hive_') as tmp_dir, NamedTemporaryFile(
+ with TemporaryDirectory(prefix="tmps32hive_") as tmp_dir, NamedTemporaryFile(
mode="wb", dir=tmp_dir, suffix=file_ext
) as f:
self.log.info("Dumping S3 key %s contents to local file %s", s3_key_object.key, f.name)
if self.select_expression:
option = {}
if self.headers:
- option['FileHeaderInfo'] = 'USE'
+ option["FileHeaderInfo"] = "USE"
if self.delimiter:
- option['FieldDelimiter'] = self.delimiter
+ option["FieldDelimiter"] = self.delimiter
- input_serialization: Dict[str, Any] = {'CSV': option}
+ input_serialization: dict[str, Any] = {"CSV": option}
if self.input_compressed:
- input_serialization['CompressionType'] = 'GZIP'
+ input_serialization["CompressionType"] = "GZIP"
content = s3_hook.select_key(
bucket_name=s3_key_object.bucket_name,
@@ -227,8 +227,7 @@ def execute(self, context: 'Context'):
def _get_top_row_as_list(self, file_name):
with open(file_name) as file:
header_line = file.readline().strip()
- header_list = header_line.split(self.delimiter)
- return header_list
+ return header_line.split(self.delimiter)
def _match_headers(self, header_list):
if not header_list:
@@ -254,13 +253,13 @@ def _match_headers(self, header_list):
def _delete_top_row_and_compress(input_file_name, output_file_ext, dest_dir):
# When output_file_ext is not defined, file is not compressed
open_fn = open
- if output_file_ext.lower() == '.gz':
+ if output_file_ext.lower() == ".gz":
open_fn = gzip.GzipFile
- elif output_file_ext.lower() == '.bz2':
+ elif output_file_ext.lower() == ".bz2":
open_fn = bz2.BZ2File
_, fn_output = tempfile.mkstemp(suffix=output_file_ext, dir=dest_dir)
- with open(input_file_name, 'rb') as f_in, open_fn(fn_output, 'wb') as f_out:
+ with open(input_file_name, "rb") as f_in, open_fn(fn_output, "wb") as f_out:
f_in.seek(0)
next(f_in)
for line in f_in:
diff --git a/airflow/providers/apache/hive/transfers/vertica_to_hive.py b/airflow/providers/apache/hive/transfers/vertica_to_hive.py
index 7e53638b0c70a..92a677af92cdd 100644
--- a/airflow/providers/apache/hive/transfers/vertica_to_hive.py
+++ b/airflow/providers/apache/hive/transfers/vertica_to_hive.py
@@ -15,12 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains an operator to move data from Vertica to Hive."""
+from __future__ import annotations
from collections import OrderedDict
from tempfile import NamedTemporaryFile
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
import unicodecsv as csv
@@ -60,10 +60,10 @@ class VerticaToHiveOperator(BaseOperator):
:ref:`Hive CLI connection id `.
"""
- template_fields: Sequence[str] = ('sql', 'partition', 'hive_table')
- template_ext: Sequence[str] = ('.sql',)
- template_fields_renderers = {'sql': 'sql'}
- ui_color = '#b4e0ff'
+ template_fields: Sequence[str] = ("sql", "partition", "hive_table")
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#b4e0ff"
def __init__(
self,
@@ -72,10 +72,10 @@ def __init__(
hive_table: str,
create: bool = True,
recreate: bool = False,
- partition: Optional[Dict] = None,
+ partition: dict | None = None,
delimiter: str = chr(1),
- vertica_conn_id: str = 'vertica_default',
- hive_cli_conn_id: str = 'hive_cli_default',
+ vertica_conn_id: str = "vertica_default",
+ hive_cli_conn_id: str = "hive_cli_default",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -97,16 +97,16 @@ def type_map(cls, vertica_type):
https://github.com/uber/vertica-python/blob/master/vertica_python/vertica/column.py
"""
type_map = {
- 5: 'BOOLEAN',
- 6: 'INT',
- 7: 'FLOAT',
- 8: 'STRING',
- 9: 'STRING',
- 16: 'FLOAT',
+ 5: "BOOLEAN",
+ 6: "INT",
+ 7: "FLOAT",
+ 8: "STRING",
+ 9: "STRING",
+ 16: "FLOAT",
}
- return type_map.get(vertica_type, 'STRING')
+ return type_map.get(vertica_type, "STRING")
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hive = HiveCliHook(hive_cli_conn_id=self.hive_cli_conn_id)
vertica = VerticaHook(vertica_conn_id=self.vertica_conn_id)
@@ -115,13 +115,11 @@ def execute(self, context: 'Context'):
cursor = conn.cursor()
cursor.execute(self.sql)
with NamedTemporaryFile("w") as f:
- csv_writer = csv.writer(f, delimiter=self.delimiter, encoding='utf-8')
+ csv_writer = csv.writer(f, delimiter=self.delimiter, encoding="utf-8")
field_dict = OrderedDict()
- col_count = 0
- for field in cursor.description:
- col_count += 1
+ for col_count, field in enumerate(cursor.description, start=1):
col_position = f"Column{col_count}"
- field_dict[col_position if field[0] == '' else field[0]] = self.type_map(field[1])
+ field_dict[col_position if field[0] == "" else field[0]] = self.type_map(field[1])
csv_writer.writerows(cursor.iterate())
f.flush()
cursor.close()
diff --git a/airflow/providers/apache/kylin/.latest-doc-only-change.txt b/airflow/providers/apache/kylin/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/apache/kylin/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/kylin/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/kylin/CHANGELOG.rst b/airflow/providers/apache/kylin/CHANGELOG.rst
index 5b25eda615175..f15f8611cf060 100644
--- a/airflow/providers/apache/kylin/CHANGELOG.rst
+++ b/airflow/providers/apache/kylin/CHANGELOG.rst
@@ -16,9 +16,60 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.1.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Bug Fixes
+~~~~~~~~~
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Remove "bad characters" from our codebase (#24841)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Misc
+~~~~
+
+* ``AIP-47 - Migrate kylin DAGs to new design #22439 (#24205)``
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.0.4
.....
diff --git a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py b/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py
deleted file mode 100644
index 0d68b36d6534e..0000000000000
--- a/airflow/providers/apache/kylin/example_dags/example_kylin_dag.py
+++ /dev/null
@@ -1,114 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-This is an example DAG which uses the KylinCubeOperator.
-The tasks below include kylin build, refresh, merge operation.
-"""
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.apache.kylin.operators.kylin_cube import KylinCubeOperator
-
-dag = DAG(
- dag_id='example_kylin_operator',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- catchup=False,
- default_args={'project': 'learn_kylin', 'cube': 'kylin_sales_cube'},
- tags=['example'],
-)
-
-
-@dag.task
-def gen_build_time():
- """
- Gen build time and push to XCom (with key of "return_value")
- :return: A dict with build time values.
- """
- return {'date_start': '1325347200000', 'date_end': '1325433600000'}
-
-
-gen_build_time_task = gen_build_time()
-gen_build_time_output_date_start = gen_build_time_task['date_start']
-gen_build_time_output_date_end = gen_build_time_task['date_end']
-
-build_task1 = KylinCubeOperator(
- task_id="kylin_build_1",
- command='build',
- start_time=gen_build_time_output_date_start,
- end_time=gen_build_time_output_date_end,
- is_track_job=True,
- dag=dag,
-)
-
-build_task2 = KylinCubeOperator(
- task_id="kylin_build_2",
- command='build',
- start_time=gen_build_time_output_date_end,
- end_time='1325520000000',
- is_track_job=True,
- dag=dag,
-)
-
-refresh_task1 = KylinCubeOperator(
- task_id="kylin_refresh_1",
- command='refresh',
- start_time=gen_build_time_output_date_start,
- end_time=gen_build_time_output_date_end,
- is_track_job=True,
- dag=dag,
-)
-
-merge_task = KylinCubeOperator(
- task_id="kylin_merge",
- command='merge',
- start_time=gen_build_time_output_date_start,
- end_time='1325520000000',
- is_track_job=True,
- dag=dag,
-)
-
-disable_task = KylinCubeOperator(
- task_id="kylin_disable",
- command='disable',
- dag=dag,
-)
-
-purge_task = KylinCubeOperator(
- task_id="kylin_purge",
- command='purge',
- dag=dag,
-)
-
-build_task3 = KylinCubeOperator(
- task_id="kylin_build_3",
- command='build',
- start_time=gen_build_time_output_date_end,
- end_time='1328730000000',
- dag=dag,
-)
-
-build_task1 >> build_task2 >> refresh_task1 >> merge_task >> disable_task >> purge_task >> build_task3
-
-# Task dependency created via `XComArgs`:
-# gen_build_time >> build_task1
-# gen_build_time >> build_task2
-# gen_build_time >> refresh_task1
-# gen_build_time >> merge_task
-# gen_build_time >> build_task3
diff --git a/airflow/providers/apache/kylin/hooks/kylin.py b/airflow/providers/apache/kylin/hooks/kylin.py
index 032b15c7e5bbf..709d4e7a281c7 100644
--- a/airflow/providers/apache/kylin/hooks/kylin.py
+++ b/airflow/providers/apache/kylin/hooks/kylin.py
@@ -15,8 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-from typing import Optional
+from __future__ import annotations
from kylinpy import exceptions, kylinpy
@@ -35,9 +34,9 @@ class KylinHook(BaseHook):
def __init__(
self,
- kylin_conn_id: str = 'kylin_default',
- project: Optional[str] = None,
- dsn: Optional[str] = None,
+ kylin_conn_id: str = "kylin_default",
+ project: str | None = None,
+ dsn: str | None = None,
):
super().__init__()
self.kylin_conn_id = kylin_conn_id
@@ -48,16 +47,15 @@ def get_conn(self):
conn = self.get_connection(self.kylin_conn_id)
if self.dsn:
return kylinpy.create_kylin(self.dsn)
- else:
- self.project = self.project if self.project else conn.schema
- return kylinpy.Kylin(
- conn.host,
- username=conn.login,
- password=conn.password,
- port=conn.port,
- project=self.project,
- **conn.extra_dejson,
- )
+ self.project = self.project or conn.schema
+ return kylinpy.Kylin(
+ conn.host,
+ username=conn.login,
+ password=conn.password,
+ port=conn.port,
+ project=self.project,
+ **conn.extra_dejson,
+ )
def cube_run(self, datasource_name, op, **op_args):
"""
@@ -70,8 +68,7 @@ def cube_run(self, datasource_name, op, **op_args):
"""
cube_source = self.get_conn().get_datasource(datasource_name)
try:
- response = cube_source.invoke_command(op, **op_args)
- return response
+ return cube_source.invoke_command(op, **op_args)
except exceptions.KylinError as err:
raise AirflowException(f"Cube operation {op} error , Message: {err}")
diff --git a/airflow/providers/apache/kylin/operators/kylin_cube.py b/airflow/providers/apache/kylin/operators/kylin_cube.py
index 5fe91ee831934..3f873cbaf54a2 100644
--- a/airflow/providers/apache/kylin/operators/kylin_cube.py
+++ b/airflow/providers/apache/kylin/operators/kylin_cube.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import time
from datetime import datetime
-from typing import TYPE_CHECKING, Optional, Sequence
+from typing import TYPE_CHECKING, Sequence
from kylinpy import kylinpy
@@ -46,14 +47,14 @@ class KylinCubeOperator(BaseOperator):
:param command: (kylin command include 'build', 'merge', 'refresh', 'delete',
'build_streaming', 'merge_streaming', 'refresh_streaming', 'disable', 'enable',
'purge', 'clone', 'drop'.
- build - use /kylin/api/cubes/{cubeName}/build rest api,and buildType is ‘BUILD’,
+ build - use /kylin/api/cubes/{cubeName}/build rest api,and buildType is 'BUILD',
and you should give start_time and end_time
- refresh - use build rest api,and buildType is ‘REFRESH’
- merge - use build rest api,and buildType is ‘MERGE’
- build_streaming - use /kylin/api/cubes/{cubeName}/build2 rest api,and buildType is ‘BUILD’
+ refresh - use build rest api,and buildType is 'REFRESH'
+ merge - use build rest api,and buildType is 'MERGE'
+ build_streaming - use /kylin/api/cubes/{cubeName}/build2 rest api,and buildType is 'BUILD'
and you should give offset_start and offset_end
- refresh_streaming - use build2 rest api,and buildType is ‘REFRESH’
- merge_streaming - use build2 rest api,and buildType is ‘MERGE’
+ refresh_streaming - use build2 rest api,and buildType is 'REFRESH'
+ merge_streaming - use build2 rest api,and buildType is 'MERGE'
delete - delete segment, and you should give segment_name value
disable - disable cube
enable - enable cube
@@ -75,41 +76,41 @@ class KylinCubeOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'project',
- 'cube',
- 'dsn',
- 'command',
- 'start_time',
- 'end_time',
- 'segment_name',
- 'offset_start',
- 'offset_end',
+ "project",
+ "cube",
+ "dsn",
+ "command",
+ "start_time",
+ "end_time",
+ "segment_name",
+ "offset_start",
+ "offset_end",
)
- ui_color = '#E79C46'
+ ui_color = "#E79C46"
build_command = {
- 'fullbuild',
- 'build',
- 'merge',
- 'refresh',
- 'build_streaming',
- 'merge_streaming',
- 'refresh_streaming',
+ "fullbuild",
+ "build",
+ "merge",
+ "refresh",
+ "build_streaming",
+ "merge_streaming",
+ "refresh_streaming",
}
jobs_end_status = {"FINISHED", "ERROR", "DISCARDED", "KILLED", "SUICIDAL", "STOPPED"}
def __init__(
self,
*,
- kylin_conn_id: str = 'kylin_default',
- project: Optional[str] = None,
- cube: Optional[str] = None,
- dsn: Optional[str] = None,
- command: Optional[str] = None,
- start_time: Optional[str] = None,
- end_time: Optional[str] = None,
- offset_start: Optional[str] = None,
- offset_end: Optional[str] = None,
- segment_name: Optional[str] = None,
+ kylin_conn_id: str = "kylin_default",
+ project: str | None = None,
+ cube: str | None = None,
+ dsn: str | None = None,
+ command: str | None = None,
+ start_time: str | None = None,
+ end_time: str | None = None,
+ offset_start: str | None = None,
+ offset_end: str | None = None,
+ segment_name: str | None = None,
is_track_job: bool = False,
interval: int = 60,
timeout: int = 60 * 60 * 24,
@@ -133,24 +134,24 @@ def __init__(
self.eager_error_status = eager_error_status
self.jobs_error_status = [stat.upper() for stat in eager_error_status]
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
_hook = KylinHook(kylin_conn_id=self.kylin_conn_id, project=self.project, dsn=self.dsn)
_support_invoke_command = kylinpy.CubeSource.support_invoke_command
if not self.command:
- raise AirflowException(f'Kylin:Command {self.command} can not be empty')
+ raise AirflowException(f"Kylin:Command {self.command} can not be empty")
if self.command.lower() not in _support_invoke_command:
raise AirflowException(
- f'Kylin:Command {self.command} can not match kylin command list {_support_invoke_command}'
+ f"Kylin:Command {self.command} can not match kylin command list {_support_invoke_command}"
)
kylinpy_params = {
- 'start': datetime.fromtimestamp(int(self.start_time) / 1000) if self.start_time else None,
- 'end': datetime.fromtimestamp(int(self.end_time) / 1000) if self.end_time else None,
- 'name': self.segment_name,
- 'offset_start': int(self.offset_start) if self.offset_start else None,
- 'offset_end': int(self.offset_end) if self.offset_end else None,
+ "start": datetime.fromtimestamp(int(self.start_time) / 1000) if self.start_time else None,
+ "end": datetime.fromtimestamp(int(self.end_time) / 1000) if self.end_time else None,
+ "name": self.segment_name,
+ "offset_start": int(self.offset_start) if self.offset_start else None,
+ "offset_end": int(self.offset_end) if self.offset_end else None,
}
rsp_data = _hook.cube_run(self.cube, self.command.lower(), **kylinpy_params)
if self.is_track_job and self.command.lower() in self.build_command:
@@ -163,13 +164,13 @@ def execute(self, context: 'Context'):
job_status = None
while job_status not in self.jobs_end_status:
if time.monotonic() - started_at > self.timeout:
- raise AirflowException(f'kylin job {job_id} timeout')
+ raise AirflowException(f"kylin job {job_id} timeout")
time.sleep(self.interval)
job_status = _hook.get_job_status(job_id)
- self.log.info('Kylin job status is %s ', job_status)
+ self.log.info("Kylin job status is %s ", job_status)
if job_status in self.jobs_error_status:
- raise AirflowException(f'Kylin job {job_id} status {job_status} is error ')
+ raise AirflowException(f"Kylin job {job_id} status {job_status} is error ")
if self.do_xcom_push:
return rsp_data
diff --git a/airflow/providers/apache/kylin/provider.yaml b/airflow/providers/apache/kylin/provider.yaml
index 613042f784a92..73f69523e11e8 100644
--- a/airflow/providers/apache/kylin/provider.yaml
+++ b/airflow/providers/apache/kylin/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Apache Kylin `__
versions:
+ - 3.1.0
+ - 3.0.0
- 2.0.4
- 2.0.3
- 2.0.2
@@ -30,8 +32,9 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - kylinpy>=2.6
integrations:
- integration-name: Apache Kylin
diff --git a/airflow/providers/apache/livy/.latest-doc-only-change.txt b/airflow/providers/apache/livy/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/apache/livy/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/livy/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/livy/CHANGELOG.rst b/airflow/providers/apache/livy/CHANGELOG.rst
index 43556648aed02..7a4e720bdd9f4 100644
--- a/airflow/providers/apache/livy/CHANGELOG.rst
+++ b/airflow/providers/apache/livy/CHANGELOG.rst
@@ -16,9 +16,70 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.2.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Features
+~~~~~~~~
+
+* ``Add template to livy operator documentation (#27404)``
+* ``Add Spark's 'appId' to xcom output (#27376)``
+* ``add template field renderer to livy operator (#27321)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Add auth_type to LivyHook (#25183)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``AIP-47 - Migrate livy DAGs to new design #22439 (#24208)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.2.3
.....
diff --git a/airflow/providers/apache/livy/example_dags/example_livy.py b/airflow/providers/apache/livy/example_dags/example_livy.py
deleted file mode 100644
index cf7bbbfb1b18d..0000000000000
--- a/airflow/providers/apache/livy/example_dags/example_livy.py
+++ /dev/null
@@ -1,50 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-This is an example DAG which uses the LivyOperator.
-The tasks below trigger the computation of pi on the Spark instance
-using the Java and Python executables provided in the example library.
-"""
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.apache.livy.operators.livy import LivyOperator
-
-with DAG(
- dag_id='example_livy_operator',
- default_args={'args': [10]},
- schedule_interval='@daily',
- start_date=datetime(2021, 1, 1),
- catchup=False,
-) as dag:
-
- # [START create_livy]
- livy_java_task = LivyOperator(
- task_id='pi_java_task',
- file='/spark-examples.jar',
- num_executors=1,
- conf={
- 'spark.shuffle.compress': 'false',
- },
- class_name='org.apache.spark.examples.SparkPi',
- )
-
- livy_python_task = LivyOperator(task_id='pi_python_task', file='/pi.py', polling_interval=60)
-
- livy_java_task >> livy_python_task
- # [END create_livy]
diff --git a/airflow/providers/apache/livy/hooks/livy.py b/airflow/providers/apache/livy/hooks/livy.py
index e809dc02c7fb4..8c3e49a545431 100644
--- a/airflow/providers/apache/livy/hooks/livy.py
+++ b/airflow/providers/apache/livy/hooks/livy.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains the Apache Livy hook."""
+from __future__ import annotations
+
import json
import re
from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Union
+from typing import Any, Sequence
import requests
@@ -31,16 +32,16 @@
class BatchState(Enum):
"""Batch session states"""
- NOT_STARTED = 'not_started'
- STARTING = 'starting'
- RUNNING = 'running'
- IDLE = 'idle'
- BUSY = 'busy'
- SHUTTING_DOWN = 'shutting_down'
- ERROR = 'error'
- DEAD = 'dead'
- KILLED = 'killed'
- SUCCESS = 'success'
+ NOT_STARTED = "not_started"
+ STARTING = "starting"
+ RUNNING = "running"
+ IDLE = "idle"
+ BUSY = "busy"
+ SHUTTING_DOWN = "shutting_down"
+ ERROR = "error"
+ DEAD = "dead"
+ KILLED = "killed"
+ SUCCESS = "success"
class LivyHook(HttpHook, LoggingMixin):
@@ -50,6 +51,7 @@ class LivyHook(HttpHook, LoggingMixin):
:param livy_conn_id: reference to a pre-defined Livy Connection.
:param extra_options: A dictionary of options passed to Livy.
:param extra_headers: A dictionary of headers passed to the HTTP request to livy.
+ :param auth_type: The auth type for the service.
.. seealso::
For more details refer to the Apache Livy API reference:
@@ -63,30 +65,31 @@ class LivyHook(HttpHook, LoggingMixin):
BatchState.ERROR,
}
- _def_headers = {'Content-Type': 'application/json', 'Accept': 'application/json'}
+ _def_headers = {"Content-Type": "application/json", "Accept": "application/json"}
- conn_name_attr = 'livy_conn_id'
- default_conn_name = 'livy_default'
- conn_type = 'livy'
- hook_name = 'Apache Livy'
+ conn_name_attr = "livy_conn_id"
+ default_conn_name = "livy_default"
+ conn_type = "livy"
+ hook_name = "Apache Livy"
def __init__(
self,
livy_conn_id: str = default_conn_name,
- extra_options: Optional[Dict[str, Any]] = None,
- extra_headers: Optional[Dict[str, Any]] = None,
+ extra_options: dict[str, Any] | None = None,
+ extra_headers: dict[str, Any] | None = None,
+ auth_type: Any | None = None,
) -> None:
super().__init__(http_conn_id=livy_conn_id)
self.extra_headers = extra_headers or {}
self.extra_options = extra_options or {}
+ self.auth_type = auth_type or self.auth_type
- def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any:
+ def get_conn(self, headers: dict[str, Any] | None = None) -> Any:
"""
Returns http session for use with requests
:param headers: additional headers to be passed through as a dictionary
:return: requests session
- :rtype: requests.Session
"""
tmp_headers = self._def_headers.copy() # setting default headers
if headers:
@@ -96,10 +99,10 @@ def get_conn(self, headers: Optional[Dict[str, Any]] = None) -> Any:
def run_method(
self,
endpoint: str,
- method: str = 'GET',
- data: Optional[Any] = None,
- headers: Optional[Dict[str, Any]] = None,
- retry_args: Optional[Dict[str, Any]] = None,
+ method: str = "GET",
+ data: Any | None = None,
+ headers: dict[str, Any] | None = None,
+ retry_args: dict[str, Any] | None = None,
) -> Any:
"""
Wrapper for HttpHook, allows to change method on the same HttpHook
@@ -111,12 +114,11 @@ def run_method(
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
:return: http response
- :rtype: requests.Response
"""
- if method not in ('GET', 'POST', 'PUT', 'DELETE', 'HEAD'):
+ if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"):
raise ValueError(f"Invalid http method '{method}'")
if not self.extra_options:
- self.extra_options = {'check_response': False}
+ self.extra_options = {"check_response": False}
back_method = self.method
self.method = method
@@ -136,12 +138,11 @@ def run_method(
self.method = back_method
return result
- def post_batch(self, *args: Any, **kwargs: Any) -> Any:
+ def post_batch(self, *args: Any, **kwargs: Any) -> int:
"""
Perform request to submit batch
:return: batch session id
- :rtype: int
"""
batch_submit_body = json.dumps(self.build_post_batch_body(*args, **kwargs))
@@ -151,7 +152,7 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any:
self.log.info("Submitting job %s to %s", batch_submit_body, self.base_url)
response = self.run_method(
- method='POST', endpoint='/batches', data=batch_submit_body, headers=self.extra_headers
+ method="POST", endpoint="/batches", data=batch_submit_body, headers=self.extra_headers
)
self.log.debug("Got response: %s", response.text)
@@ -170,18 +171,17 @@ def post_batch(self, *args: Any, **kwargs: Any) -> Any:
return batch_id
- def get_batch(self, session_id: Union[int, str]) -> Any:
+ def get_batch(self, session_id: int | str) -> dict:
"""
Fetch info about the specified batch
:param session_id: identifier of the batch sessions
:return: response body
- :rtype: dict
"""
self._validate_session_id(session_id)
self.log.debug("Fetching info for batch session %d", session_id)
- response = self.run_method(endpoint=f'/batches/{session_id}', headers=self.extra_headers)
+ response = self.run_method(endpoint=f"/batches/{session_id}", headers=self.extra_headers)
try:
response.raise_for_status()
@@ -193,9 +193,7 @@ def get_batch(self, session_id: Union[int, str]) -> Any:
return response.json()
- def get_batch_state(
- self, session_id: Union[int, str], retry_args: Optional[Dict[str, Any]] = None
- ) -> BatchState:
+ def get_batch_state(self, session_id: int | str, retry_args: dict[str, Any] | None = None) -> BatchState:
"""
Fetch the state of the specified batch
@@ -203,13 +201,12 @@ def get_batch_state(
:param retry_args: Arguments which define the retry behaviour.
See Tenacity documentation at https://github.com/jd/tenacity
:return: batch state
- :rtype: BatchState
"""
self._validate_session_id(session_id)
self.log.debug("Fetching info for batch session %d", session_id)
response = self.run_method(
- endpoint=f'/batches/{session_id}/state', retry_args=retry_args, headers=self.extra_headers
+ endpoint=f"/batches/{session_id}/state", retry_args=retry_args, headers=self.extra_headers
)
try:
@@ -221,23 +218,22 @@ def get_batch_state(
)
jresp = response.json()
- if 'state' not in jresp:
+ if "state" not in jresp:
raise AirflowException(f"Unable to get state for batch with id: {session_id}")
- return BatchState(jresp['state'])
+ return BatchState(jresp["state"])
- def delete_batch(self, session_id: Union[int, str]) -> Any:
+ def delete_batch(self, session_id: int | str) -> dict:
"""
Delete the specified batch
:param session_id: identifier of the batch sessions
:return: response body
- :rtype: dict
"""
self._validate_session_id(session_id)
self.log.info("Deleting batch session %d", session_id)
response = self.run_method(
- method='DELETE', endpoint=f'/batches/{session_id}', headers=self.extra_headers
+ method="DELETE", endpoint=f"/batches/{session_id}", headers=self.extra_headers
)
try:
@@ -250,7 +246,7 @@ def delete_batch(self, session_id: Union[int, str]) -> Any:
return response.json()
- def get_batch_logs(self, session_id: Union[int, str], log_start_position, log_batch_size) -> Any:
+ def get_batch_logs(self, session_id: int | str, log_start_position, log_batch_size) -> dict:
"""
Gets the session logs for a specified batch.
:param session_id: identifier of the batch sessions
@@ -258,12 +254,11 @@ def get_batch_logs(self, session_id: Union[int, str], log_start_position, log_ba
:param log_batch_size: Number of lines to pull in one batch
:return: response body
- :rtype: dict
"""
self._validate_session_id(session_id)
- log_params = {'from': log_start_position, 'size': log_batch_size}
+ log_params = {"from": log_start_position, "size": log_batch_size}
response = self.run_method(
- endpoint=f'/batches/{session_id}/log', data=log_params, headers=self.extra_headers
+ endpoint=f"/batches/{session_id}/log", data=log_params, headers=self.extra_headers
)
try:
response.raise_for_status()
@@ -275,13 +270,12 @@ def get_batch_logs(self, session_id: Union[int, str], log_start_position, log_ba
)
return response.json()
- def dump_batch_logs(self, session_id: Union[int, str]) -> Any:
+ def dump_batch_logs(self, session_id: int | str) -> None:
"""
Dumps the session logs for a specified batch
:param session_id: identifier of the batch sessions
:return: response body
- :rtype: dict
"""
self.log.info("Fetching the logs for batch session with id: %d", session_id)
log_start_line = 0
@@ -291,14 +285,14 @@ def dump_batch_logs(self, session_id: Union[int, str]) -> Any:
while log_start_line <= log_total_lines:
# Livy log endpoint is paginated.
response = self.get_batch_logs(session_id, log_start_line, log_batch_size)
- log_total_lines = self._parse_request_response(response, 'total')
+ log_total_lines = self._parse_request_response(response, "total")
log_start_line += log_batch_size
- log_lines = self._parse_request_response(response, 'log')
+ log_lines = self._parse_request_response(response, "log")
for log_line in log_lines:
self.log.info(log_line)
@staticmethod
- def _validate_session_id(session_id: Union[int, str]) -> None:
+ def _validate_session_id(session_id: int | str) -> None:
"""
Validate session id is an int
@@ -310,46 +304,44 @@ def _validate_session_id(session_id: Union[int, str]) -> None:
raise TypeError("'session_id' must be an integer")
@staticmethod
- def _parse_post_response(response: Dict[Any, Any]) -> Any:
+ def _parse_post_response(response: dict[Any, Any]) -> int | None:
"""
Parse batch response for batch id
:param response: response body
:return: session id
- :rtype: int
"""
- return response.get('id')
+ return response.get("id")
@staticmethod
- def _parse_request_response(response: Dict[Any, Any], parameter) -> Any:
+ def _parse_request_response(response: dict[Any, Any], parameter):
"""
Parse batch response for batch id
:param response: response body
:return: value of parameter
- :rtype: Union[int, list]
"""
- return response.get(parameter)
+ return response.get(parameter, [])
@staticmethod
def build_post_batch_body(
file: str,
- args: Optional[Sequence[Union[str, int, float]]] = None,
- class_name: Optional[str] = None,
- jars: Optional[List[str]] = None,
- py_files: Optional[List[str]] = None,
- files: Optional[List[str]] = None,
- archives: Optional[List[str]] = None,
- name: Optional[str] = None,
- driver_memory: Optional[str] = None,
- driver_cores: Optional[Union[int, str]] = None,
- executor_memory: Optional[str] = None,
- executor_cores: Optional[int] = None,
- num_executors: Optional[Union[int, str]] = None,
- queue: Optional[str] = None,
- proxy_user: Optional[str] = None,
- conf: Optional[Dict[Any, Any]] = None,
- ) -> Any:
+ args: Sequence[str | int | float] | None = None,
+ class_name: str | None = None,
+ jars: list[str] | None = None,
+ py_files: list[str] | None = None,
+ files: list[str] | None = None,
+ archives: list[str] | None = None,
+ name: str | None = None,
+ driver_memory: str | None = None,
+ driver_cores: int | str | None = None,
+ executor_memory: str | None = None,
+ executor_cores: int | None = None,
+ num_executors: int | str | None = None,
+ queue: str | None = None,
+ proxy_user: str | None = None,
+ conf: dict[Any, Any] | None = None,
+ ) -> dict:
"""
Build the post batch request body.
For more information about the format refer to
@@ -371,40 +363,39 @@ def build_post_batch_body(
:param name: The name of this session string.
:param conf: Spark configuration properties.
:return: request body
- :rtype: dict
"""
- body: Dict[str, Any] = {'file': file}
+ body: dict[str, Any] = {"file": file}
if proxy_user:
- body['proxyUser'] = proxy_user
+ body["proxyUser"] = proxy_user
if class_name:
- body['className'] = class_name
+ body["className"] = class_name
if args and LivyHook._validate_list_of_stringables(args):
- body['args'] = [str(val) for val in args]
+ body["args"] = [str(val) for val in args]
if jars and LivyHook._validate_list_of_stringables(jars):
- body['jars'] = jars
+ body["jars"] = jars
if py_files and LivyHook._validate_list_of_stringables(py_files):
- body['pyFiles'] = py_files
+ body["pyFiles"] = py_files
if files and LivyHook._validate_list_of_stringables(files):
- body['files'] = files
+ body["files"] = files
if driver_memory and LivyHook._validate_size_format(driver_memory):
- body['driverMemory'] = driver_memory
+ body["driverMemory"] = driver_memory
if driver_cores:
- body['driverCores'] = driver_cores
+ body["driverCores"] = driver_cores
if executor_memory and LivyHook._validate_size_format(executor_memory):
- body['executorMemory'] = executor_memory
+ body["executorMemory"] = executor_memory
if executor_cores:
- body['executorCores'] = executor_cores
+ body["executorCores"] = executor_cores
if num_executors:
- body['numExecutors'] = num_executors
+ body["numExecutors"] = num_executors
if archives and LivyHook._validate_list_of_stringables(archives):
- body['archives'] = archives
+ body["archives"] = archives
if queue:
- body['queue'] = queue
+ body["queue"] = queue
if name:
- body['name'] = name
+ body["name"] = name
if conf and LivyHook._validate_extra_conf(conf):
- body['conf'] = conf
+ body["conf"] = conf
return body
@@ -415,20 +406,18 @@ def _validate_size_format(size: str) -> bool:
:param size: size value
:return: true if valid format
- :rtype: bool
"""
- if size and not (isinstance(size, str) and re.match(r'^\d+[kmgt]b?$', size, re.IGNORECASE)):
+ if size and not (isinstance(size, str) and re.match(r"^\d+[kmgt]b?$", size, re.IGNORECASE)):
raise ValueError(f"Invalid java size format for string'{size}'")
return True
@staticmethod
- def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> bool:
+ def _validate_list_of_stringables(vals: Sequence[str | int | float]) -> bool:
"""
Check the values in the provided list can be converted to strings.
:param vals: list to validate
:return: true if valid
- :rtype: bool
"""
if (
vals is None
@@ -439,13 +428,12 @@ def _validate_list_of_stringables(vals: Sequence[Union[str, int, float]]) -> boo
return True
@staticmethod
- def _validate_extra_conf(conf: Dict[Any, Any]) -> bool:
+ def _validate_extra_conf(conf: dict[Any, Any]) -> bool:
"""
Check configuration values are either strings or ints.
:param conf: configuration variable
:return: true if valid
- :rtype: bool
"""
if conf:
if not isinstance(conf, dict):
diff --git a/airflow/providers/apache/livy/operators/livy.py b/airflow/providers/apache/livy/operators/livy.py
index f0dbc9e3165fe..313f64c9f9afd 100644
--- a/airflow/providers/apache/livy/operators/livy.py
+++ b/airflow/providers/apache/livy/operators/livy.py
@@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains the Apache Livy operator."""
+from __future__ import annotations
+
from time import sleep
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -32,23 +33,24 @@ class LivyOperator(BaseOperator):
This operator wraps the Apache Livy batch REST API, allowing to submit a Spark
application to the underlying cluster.
- :param file: path of the file containing the application to execute (required).
- :param class_name: name of the application Java/Spark main class.
- :param args: application command line arguments.
- :param jars: jars to be used in this sessions.
- :param py_files: python files to be used in this session.
- :param files: files to be used in this session.
- :param driver_memory: amount of memory to use for the driver process.
- :param driver_cores: number of cores to use for the driver process.
- :param executor_memory: amount of memory to use per executor process.
- :param executor_cores: number of cores to use for each executor.
- :param num_executors: number of executors to launch for this session.
- :param archives: archives to be used in this session.
- :param queue: name of the YARN queue to which the application is submitted.
- :param name: name of this session.
- :param conf: Spark configuration properties.
- :param proxy_user: user to impersonate when running the job.
+ :param file: path of the file containing the application to execute (required). (templated)
+ :param class_name: name of the application Java/Spark main class. (templated)
+ :param args: application command line arguments. (templated)
+ :param jars: jars to be used in this sessions. (templated)
+ :param py_files: python files to be used in this session. (templated)
+ :param files: files to be used in this session. (templated)
+ :param driver_memory: amount of memory to use for the driver process. (templated)
+ :param driver_cores: number of cores to use for the driver process. (templated)
+ :param executor_memory: amount of memory to use per executor process. (templated)
+ :param executor_cores: number of cores to use for each executor. (templated)
+ :param num_executors: number of executors to launch for this session. (templated)
+ :param archives: archives to be used in this session. (templated)
+ :param queue: name of the YARN queue to which the application is submitted. (templated)
+ :param name: name of this session. (templated)
+ :param conf: Spark configuration properties. (templated)
+ :param proxy_user: user to impersonate when running the job. (templated)
:param livy_conn_id: reference to a pre-defined Livy Connection.
+ :param livy_conn_auth_type: The auth type for the Livy Connection.
:param polling_interval: time in seconds between polling for job completion. Don't poll for values <=0
:param extra_options: A dictionary of options, where key is string and value
depends on the option that's being modified.
@@ -57,63 +59,66 @@ class LivyOperator(BaseOperator):
See Tenacity documentation at https://github.com/jd/tenacity
"""
- template_fields: Sequence[str] = ('spark_params',)
+ template_fields: Sequence[str] = ("spark_params",)
+ template_fields_renderers = {"spark_params": "json"}
def __init__(
self,
*,
file: str,
- class_name: Optional[str] = None,
- args: Optional[Sequence[Union[str, int, float]]] = None,
- conf: Optional[Dict[Any, Any]] = None,
- jars: Optional[Sequence[str]] = None,
- py_files: Optional[Sequence[str]] = None,
- files: Optional[Sequence[str]] = None,
- driver_memory: Optional[str] = None,
- driver_cores: Optional[Union[int, str]] = None,
- executor_memory: Optional[str] = None,
- executor_cores: Optional[Union[int, str]] = None,
- num_executors: Optional[Union[int, str]] = None,
- archives: Optional[Sequence[str]] = None,
- queue: Optional[str] = None,
- name: Optional[str] = None,
- proxy_user: Optional[str] = None,
- livy_conn_id: str = 'livy_default',
+ class_name: str | None = None,
+ args: Sequence[str | int | float] | None = None,
+ conf: dict[Any, Any] | None = None,
+ jars: Sequence[str] | None = None,
+ py_files: Sequence[str] | None = None,
+ files: Sequence[str] | None = None,
+ driver_memory: str | None = None,
+ driver_cores: int | str | None = None,
+ executor_memory: str | None = None,
+ executor_cores: int | str | None = None,
+ num_executors: int | str | None = None,
+ archives: Sequence[str] | None = None,
+ queue: str | None = None,
+ name: str | None = None,
+ proxy_user: str | None = None,
+ livy_conn_id: str = "livy_default",
+ livy_conn_auth_type: Any | None = None,
polling_interval: int = 0,
- extra_options: Optional[Dict[str, Any]] = None,
- extra_headers: Optional[Dict[str, Any]] = None,
- retry_args: Optional[Dict[str, Any]] = None,
+ extra_options: dict[str, Any] | None = None,
+ extra_headers: dict[str, Any] | None = None,
+ retry_args: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.spark_params = {
- 'file': file,
- 'class_name': class_name,
- 'args': args,
- 'jars': jars,
- 'py_files': py_files,
- 'files': files,
- 'driver_memory': driver_memory,
- 'driver_cores': driver_cores,
- 'executor_memory': executor_memory,
- 'executor_cores': executor_cores,
- 'num_executors': num_executors,
- 'archives': archives,
- 'queue': queue,
- 'name': name,
- 'conf': conf,
- 'proxy_user': proxy_user,
+ "file": file,
+ "class_name": class_name,
+ "args": args,
+ "jars": jars,
+ "py_files": py_files,
+ "files": files,
+ "driver_memory": driver_memory,
+ "driver_cores": driver_cores,
+ "executor_memory": executor_memory,
+ "executor_cores": executor_cores,
+ "num_executors": num_executors,
+ "archives": archives,
+ "queue": queue,
+ "name": name,
+ "conf": conf,
+ "proxy_user": proxy_user,
}
self._livy_conn_id = livy_conn_id
+ self._livy_conn_auth_type = livy_conn_auth_type
self._polling_interval = polling_interval
self._extra_options = extra_options or {}
self._extra_headers = extra_headers or {}
- self._livy_hook: Optional[LivyHook] = None
- self._batch_id: Union[int, str]
+ self._livy_hook: LivyHook | None = None
+ self._batch_id: int | str
self.retry_args = retry_args
def get_hook(self) -> LivyHook:
@@ -121,25 +126,27 @@ def get_hook(self) -> LivyHook:
Get valid hook.
:return: hook
- :rtype: LivyHook
"""
if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook):
self._livy_hook = LivyHook(
livy_conn_id=self._livy_conn_id,
extra_headers=self._extra_headers,
extra_options=self._extra_options,
+ auth_type=self._livy_conn_auth_type,
)
return self._livy_hook
- def execute(self, context: "Context") -> Any:
+ def execute(self, context: Context) -> Any:
self._batch_id = self.get_hook().post_batch(**self.spark_params)
if self._polling_interval > 0:
self.poll_for_termination(self._batch_id)
+ context["ti"].xcom_push(key="app_id", value=self.get_hook().get_batch(self._batch_id)["appId"])
+
return self._batch_id
- def poll_for_termination(self, batch_id: Union[int, str]) -> None:
+ def poll_for_termination(self, batch_id: int | str) -> None:
"""
Poll Livy for batch termination.
@@ -148,7 +155,7 @@ def poll_for_termination(self, batch_id: Union[int, str]) -> None:
hook = self.get_hook()
state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
while state not in hook.TERMINAL_STATES:
- self.log.debug('Batch with id %s is in state: %s', batch_id, state.value)
+ self.log.debug("Batch with id %s is in state: %s", batch_id, state.value)
sleep(self._polling_interval)
state = hook.get_batch_state(batch_id, retry_args=self.retry_args)
self.log.info("Batch with id %s terminated with state: %s", batch_id, state.value)
diff --git a/airflow/providers/apache/livy/provider.yaml b/airflow/providers/apache/livy/provider.yaml
index c624e0b9d360a..8e69ef935d03a 100644
--- a/airflow/providers/apache/livy/provider.yaml
+++ b/airflow/providers/apache/livy/provider.yaml
@@ -22,6 +22,9 @@ description: |
`Apache Livy `__
versions:
+ - 3.2.0
+ - 3.1.0
+ - 3.0.0
- 2.2.3
- 2.2.2
- 2.2.1
@@ -32,8 +35,9 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-http
integrations:
- integration-name: Apache Livy
@@ -58,9 +62,6 @@ hooks:
python-modules:
- airflow.providers.apache.livy.hooks.livy
-hook-class-names:
- # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.livy.hooks.livy.LivyHook
connection-types:
- hook-class-name: airflow.providers.apache.livy.hooks.livy.LivyHook
diff --git a/airflow/providers/apache/livy/sensors/livy.py b/airflow/providers/apache/livy/sensors/livy.py
index 4c3419f2af4b2..d0838b187a965 100644
--- a/airflow/providers/apache/livy/sensors/livy.py
+++ b/airflow/providers/apache/livy/sensors/livy.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module contains the Apache Livy sensor."""
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.providers.apache.livy.hooks.livy import LivyHook
from airflow.sensors.base import BaseSensorOperator
@@ -34,20 +35,22 @@ class LivySensor(BaseSensorOperator):
depends on the option that's being modified.
"""
- template_fields: Sequence[str] = ('batch_id',)
+ template_fields: Sequence[str] = ("batch_id",)
def __init__(
self,
*,
- batch_id: Union[int, str],
- livy_conn_id: str = 'livy_default',
- extra_options: Optional[Dict[str, Any]] = None,
+ batch_id: int | str,
+ livy_conn_id: str = "livy_default",
+ livy_conn_auth_type: Any | None = None,
+ extra_options: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.batch_id = batch_id
self._livy_conn_id = livy_conn_id
- self._livy_hook: Optional[LivyHook] = None
+ self._livy_conn_auth_type = livy_conn_auth_type
+ self._livy_hook: LivyHook | None = None
self._extra_options = extra_options or {}
def get_hook(self) -> LivyHook:
@@ -55,13 +58,16 @@ def get_hook(self) -> LivyHook:
Get valid hook.
:return: hook
- :rtype: LivyHook
"""
if self._livy_hook is None or not isinstance(self._livy_hook, LivyHook):
- self._livy_hook = LivyHook(livy_conn_id=self._livy_conn_id, extra_options=self._extra_options)
+ self._livy_hook = LivyHook(
+ livy_conn_id=self._livy_conn_id,
+ extra_options=self._extra_options,
+ auth_type=self._livy_conn_auth_type,
+ )
return self._livy_hook
- def poke(self, context: "Context") -> bool:
+ def poke(self, context: Context) -> bool:
batch_id = self.batch_id
status = self.get_hook().get_batch_state(batch_id)
diff --git a/airflow/providers/apache/pig/.latest-doc-only-change.txt b/airflow/providers/apache/pig/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/apache/pig/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/pig/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/pig/CHANGELOG.rst b/airflow/providers/apache/pig/CHANGELOG.rst
index c599c06bad8bc..f8338b3e9c91f 100644
--- a/airflow/providers/apache/pig/CHANGELOG.rst
+++ b/airflow/providers/apache/pig/CHANGELOG.rst
@@ -16,9 +16,65 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+4.0.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+You cannot use ``pig_properties`` any more as connection extras. If you want to add extra parameters
+to ``pig`` command, you need to do it via ``pig_properties`` (string list) of the PigCliHook (new parameter)
+or via ``pig_opts`` (string with options separated by spaces) or ``pig_properties`` (string list) in
+the PigOperator. Any use of ``pig_properties`` extras in connection will raise an exception,
+informing that you need to remove them and pass them as parameters.
+
+Both ``pig_properties`` and ``pig_opts`` are now templated fields in the PigOperator.
+
+* ``Pig cli connection properties cannot be passed by connection extra (#27644)``
+
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``AIP-47 - Migrate apache pig DAGs to new design #22439 (#24212)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.0.4
.....
diff --git a/airflow/providers/apache/pig/example_dags/example_pig.py b/airflow/providers/apache/pig/example_dags/example_pig.py
deleted file mode 100644
index ed1b34ab0c8a4..0000000000000
--- a/airflow/providers/apache/pig/example_dags/example_pig.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Example DAG demonstrating the usage of the PigOperator."""
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.apache.pig.operators.pig import PigOperator
-
-dag = DAG(
- dag_id='example_pig_operator',
- schedule_interval=None,
- start_date=datetime(2021, 1, 1),
- catchup=False,
- tags=['example'],
-)
-
-# [START create_pig]
-run_this = PigOperator(
- task_id="run_example_pig_script",
- pig="ls /;",
- pig_opts="-x local",
- dag=dag,
-)
-# [END create_pig]
diff --git a/airflow/providers/apache/pig/hooks/pig.py b/airflow/providers/apache/pig/hooks/pig.py
index ae9db33c3db53..023b308e13e2c 100644
--- a/airflow/providers/apache/pig/hooks/pig.py
+++ b/airflow/providers/apache/pig/hooks/pig.py
@@ -15,63 +15,72 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import subprocess
from tempfile import NamedTemporaryFile, TemporaryDirectory
-from typing import Any, List, Optional
+from typing import Any
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
class PigCliHook(BaseHook):
- """
- Simple wrapper around the pig CLI.
-
- Note that you can also set default pig CLI properties using the
- ``pig_properties`` to be used in your connection as in
- ``{"pig_properties": "-Dpig.tmpfilecompression=true"}``
+ """Simple wrapper around the pig CLI.
+ :param pig_cli_conn_id: Connection id used by the hook
+ :param pig_properties: additional properties added after pig cli command as list of strings.
"""
- conn_name_attr = 'pig_cli_conn_id'
- default_conn_name = 'pig_cli_default'
- conn_type = 'pig_cli'
- hook_name = 'Pig Client Wrapper'
+ conn_name_attr = "pig_cli_conn_id"
+ default_conn_name = "pig_cli_default"
+ conn_type = "pig_cli"
+ hook_name = "Pig Client Wrapper"
- def __init__(self, pig_cli_conn_id: str = default_conn_name) -> None:
+ def __init__(
+ self, pig_cli_conn_id: str = default_conn_name, pig_properties: list[str] | None = None
+ ) -> None:
super().__init__()
conn = self.get_connection(pig_cli_conn_id)
- self.pig_properties = conn.extra_dejson.get('pig_properties', '')
+ conn_pig_properties = conn.extra_dejson.get("pig_properties")
+ if conn_pig_properties:
+ raise RuntimeError(
+ "The PigCliHook used to accept `pig_properties` from connection extras,"
+ " however with the 4.0.0 version of `apache-pig` provider this has been removed. You should"
+ " use ``pig_opts`` (space separated string) or ``pig_properties`` (string list) in the"
+ " PigOperator. You can also pass ``pig_properties`` in the PigCliHook `init`. Currently,"
+ f" the {pig_cli_conn_id} connection has those extras: `{conn_pig_properties}`."
+ )
+ self.pig_properties = pig_properties if pig_properties else []
self.conn = conn
self.sub_process = None
- def run_cli(self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True) -> Any:
+ def run_cli(self, pig: str, pig_opts: str | None = None, verbose: bool = True) -> Any:
"""
- Run an pig script using the pig cli
+ Run a pig script using the pig cli
>>> ph = PigCliHook()
>>> result = ph.run_cli("ls /;", pig_opts="-x mapreduce")
>>> ("hdfs://" in result)
True
"""
- with TemporaryDirectory(prefix='airflow_pigop_') as tmp_dir:
+ with TemporaryDirectory(prefix="airflow_pigop_") as tmp_dir:
with NamedTemporaryFile(dir=tmp_dir) as f:
- f.write(pig.encode('utf-8'))
+ f.write(pig.encode("utf-8"))
f.flush()
fname = f.name
- pig_bin = 'pig'
- cmd_extra: List[str] = []
+ pig_bin = "pig"
+ cmd_extra: list[str] = []
pig_cmd = [pig_bin]
if self.pig_properties:
- pig_properties_list = self.pig_properties.split()
- pig_cmd.extend(pig_properties_list)
+ pig_cmd.extend(self.pig_properties)
if pig_opts:
pig_opts_list = pig_opts.split()
pig_cmd.extend(pig_opts_list)
- pig_cmd.extend(['-f', fname] + cmd_extra)
+ pig_cmd.extend(["-f", fname] + cmd_extra)
if verbose:
self.log.info("%s", " ".join(pig_cmd))
@@ -79,9 +88,9 @@ def run_cli(self, pig: str, pig_opts: Optional[str] = None, verbose: bool = True
pig_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=tmp_dir, close_fds=True
)
self.sub_process = sub_process
- stdout = ''
- for line in iter(sub_process.stdout.readline, b''):
- stdout += line.decode('utf-8')
+ stdout = ""
+ for line in iter(sub_process.stdout.readline, b""):
+ stdout += line.decode("utf-8")
if verbose:
self.log.info(line.strip())
sub_process.wait()
diff --git a/airflow/providers/apache/pig/operators/pig.py b/airflow/providers/apache/pig/operators/pig.py
index 5a285530dc637..544895cea09dd 100644
--- a/airflow/providers/apache/pig/operators/pig.py
+++ b/airflow/providers/apache/pig/operators/pig.py
@@ -15,8 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import re
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
from airflow.providers.apache.pig.hooks.pig import PigCliHook
@@ -36,23 +38,25 @@ class PigOperator(BaseOperator):
you may want to use this along with the
``DAG(user_defined_macros=myargs)`` parameter. View the DAG
object documentation for more details.
- :param pig_opts: pig options, such as: -x tez, -useHCatalog, ...
+ :param pig_opts: pig options, such as: -x tez, -useHCatalog, ... - space separated list
+ :param pig_properties: pig properties, additional pig properties passed as list
"""
- template_fields: Sequence[str] = ('pig',)
+ template_fields: Sequence[str] = ("pig", "pig_opts", "pig_properties")
template_ext: Sequence[str] = (
- '.pig',
- '.piglatin',
+ ".pig",
+ ".piglatin",
)
- ui_color = '#f0e4ec'
+ ui_color = "#f0e4ec"
def __init__(
self,
*,
pig: str,
- pig_cli_conn_id: str = 'pig_cli_default',
+ pig_cli_conn_id: str = "pig_cli_default",
pigparams_jinja_translate: bool = False,
- pig_opts: Optional[str] = None,
+ pig_opts: str | None = None,
+ pig_properties: list[str] | None = None,
**kwargs: Any,
) -> None:
@@ -61,15 +65,16 @@ def __init__(
self.pig = pig
self.pig_cli_conn_id = pig_cli_conn_id
self.pig_opts = pig_opts
- self.hook: Optional[PigCliHook] = None
+ self.pig_properties = pig_properties
+ self.hook: PigCliHook | None = None
def prepare_template(self):
if self.pigparams_jinja_translate:
self.pig = re.sub(r"(\$([a-zA-Z_][a-zA-Z0-9_]*))", r"{{ \g<2> }}", self.pig)
- def execute(self, context: 'Context'):
- self.log.info('Executing: %s', self.pig)
- self.hook = PigCliHook(pig_cli_conn_id=self.pig_cli_conn_id)
+ def execute(self, context: Context):
+ self.log.info("Executing: %s", self.pig)
+ self.hook = PigCliHook(pig_cli_conn_id=self.pig_cli_conn_id, pig_properties=self.pig_properties)
self.hook.run_cli(pig=self.pig, pig_opts=self.pig_opts)
def on_kill(self):
diff --git a/airflow/providers/apache/pig/provider.yaml b/airflow/providers/apache/pig/provider.yaml
index 81540ae4420f6..df1c55a8a7b61 100644
--- a/airflow/providers/apache/pig/provider.yaml
+++ b/airflow/providers/apache/pig/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Apache Pig `__
versions:
+ - 4.0.0
+ - 3.0.0
- 2.0.4
- 2.0.3
- 2.0.2
@@ -30,8 +32,8 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
integrations:
- integration-name: Apache Pig
@@ -51,10 +53,6 @@ hooks:
python-modules:
- airflow.providers.apache.pig.hooks.pig
-hook-class-names:
- # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.pig.hooks.pig.PigCliHook
-
connection-types:
- connection-type: pig_cli
hook-class-name: airflow.providers.apache.pig.hooks.pig.PigCliHook
diff --git a/airflow/providers/apache/pinot/CHANGELOG.rst b/airflow/providers/apache/pinot/CHANGELOG.rst
index 923a27315fe43..cafdffda1dbcd 100644
--- a/airflow/providers/apache/pinot/CHANGELOG.rst
+++ b/airflow/providers/apache/pinot/CHANGELOG.rst
@@ -16,9 +16,98 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+4.0.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The admin command is now hard-coded to ``pinot-admin.sh``. The ``pinot-admin.sh`` command must be available
+on the path in order to use PinotAdminHook.
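
The guard behind this breaking change can be sketched as below. It is a simplified, standalone illustration of the check added to ``PinotAdminHook.__init__`` in this change; the function name ``resolve_cmd_path`` and the constant ``DEFAULT_CMD_PATH`` are hypothetical names used only for this example.

```python
DEFAULT_CMD_PATH = "pinot-admin.sh"  # hypothetical constant name for this sketch


def resolve_cmd_path(cmd_path: str = DEFAULT_CMD_PATH) -> str:
    """Reject any value other than the hard-coded admin command."""
    if cmd_path != DEFAULT_CMD_PATH:
        raise RuntimeError(
            "cmd_path is hard-coded to pinot-admin.sh; "
            "make sure pinot-admin.sh is on your PATH and do not change cmd_path."
        )
    return DEFAULT_CMD_PATH


# The default is accepted unchanged.
accepted = resolve_cmd_path()

# A custom location is rejected, even if it points at a real pinot-admin.sh.
try:
    resolve_cmd_path("/opt/pinot/bin/pinot-admin.sh")
    rejected = False
except RuntimeError:
    rejected = True
```

Callers that previously pointed ``cmd_path`` at a custom location would instead add that directory to ``PATH``.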
+
+* ``The pinot-admin.sh command is now hard-coded. (#27641)``
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+* ``Bump pinotdb version (#27201)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+3.2.1
+.....
+
+Misc
+~~~~
+
+* ``Add common-sql lower bound for common-sql (#25789)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix PinotDB dependencies (#26705)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+3.2.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)``
+* ``Unify DbApiHook.run() method with the methods which override it (#23971)``
+
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Move all SQL classes to common-sql provider (#24836)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Misc
+~~~~
+
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.0.4
.....
diff --git a/airflow/providers/apache/pinot/hooks/pinot.py b/airflow/providers/apache/pinot/hooks/pinot.py
index 55ddce0bcc0cc..909b0182ddcb5 100644
--- a/airflow/providers/apache/pinot/hooks/pinot.py
+++ b/airflow/providers/apache/pinot/hooks/pinot.py
@@ -15,17 +15,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
import os
import subprocess
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Iterable, Mapping
from pinotdb import connect
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
-from airflow.hooks.dbapi import DbApiHook
from airflow.models import Connection
+from airflow.providers.common.sql.hooks.sql import DbApiHook
class PinotAdminHook(BaseHook):
@@ -45,7 +46,10 @@ class PinotAdminHook(BaseHook):
following PR: https://github.com/apache/incubator-pinot/pull/4110
:param conn_id: The name of the connection to use.
- :param cmd_path: The filepath to the pinot-admin.sh executable
+ :param cmd_path: Do not modify this parameter. It used to be the filepath to the pinot-admin.sh
+ executable, but since version 4.0.0 of the apache-pinot provider its value must
+ remain the default: `pinot-admin.sh`. It is kept so that positional arguments do not
+ accidentally override `pinot_admin_system_exit`.
:param pinot_admin_system_exit: If true, the result is evaluated based on the status code.
Otherwise, the result is evaluated as a failure if "Error" or
"Exception" is in the output message.
@@ -61,7 +65,15 @@ def __init__(
conn = self.get_connection(conn_id)
self.host = conn.host
self.port = str(conn.port)
- self.cmd_path = conn.extra_dejson.get("cmd_path", cmd_path)
+ if cmd_path != "pinot-admin.sh":
+ raise RuntimeError(
+ "In version 4.0.0 of the PinotAdminHook the cmd_path has been hard-coded to"
+ " pinot-admin.sh. In order to avoid accidentally passing a positional"
+ " `pinot_admin_system_exit` value in its place, the `cmd_path`"
+ " parameter is left here but you should not modify it. Make sure that"
+ " `pinot-admin.sh` is on your PATH and do not change the cmd_path value."
+ )
+ self.cmd_path = "pinot-admin.sh"
self.pinot_admin_system_exit = conn.extra_dejson.get(
"pinot_admin_system_exit", pinot_admin_system_exit
)
@@ -102,24 +114,24 @@ def add_table(self, file_path: str, with_exec: bool = True) -> Any:
def create_segment(
self,
- generator_config_file: Optional[str] = None,
- data_dir: Optional[str] = None,
- segment_format: Optional[str] = None,
- out_dir: Optional[str] = None,
- overwrite: Optional[str] = None,
- table_name: Optional[str] = None,
- segment_name: Optional[str] = None,
- time_column_name: Optional[str] = None,
- schema_file: Optional[str] = None,
- reader_config_file: Optional[str] = None,
- enable_star_tree_index: Optional[str] = None,
- star_tree_index_spec_file: Optional[str] = None,
- hll_size: Optional[str] = None,
- hll_columns: Optional[str] = None,
- hll_suffix: Optional[str] = None,
- num_threads: Optional[str] = None,
- post_creation_verification: Optional[str] = None,
- retry: Optional[str] = None,
+ generator_config_file: str | None = None,
+ data_dir: str | None = None,
+ segment_format: str | None = None,
+ out_dir: str | None = None,
+ overwrite: str | None = None,
+ table_name: str | None = None,
+ segment_name: str | None = None,
+ time_column_name: str | None = None,
+ schema_file: str | None = None,
+ reader_config_file: str | None = None,
+ enable_star_tree_index: str | None = None,
+ star_tree_index_spec_file: str | None = None,
+ hll_size: str | None = None,
+ hll_columns: str | None = None,
+ hll_suffix: str | None = None,
+ num_threads: str | None = None,
+ post_creation_verification: str | None = None,
+ retry: str | None = None,
) -> Any:
"""Create Pinot segment by run CreateSegment command"""
cmd = ["CreateSegment"]
@@ -180,7 +192,7 @@ def create_segment(
self.run_cli(cmd)
- def upload_segment(self, segment_dir: str, table_name: Optional[str] = None) -> Any:
+ def upload_segment(self, segment_dir: str, table_name: str | None = None) -> Any:
"""
Upload Segment with run UploadSegment command
@@ -196,16 +208,14 @@ def upload_segment(self, segment_dir: str, table_name: Optional[str] = None) ->
cmd += ["-tableName", table_name]
self.run_cli(cmd)
- def run_cli(self, cmd: List[str], verbose: bool = True) -> str:
+ def run_cli(self, cmd: list[str], verbose: bool = True) -> str:
"""
Run command with pinot-admin.sh
:param cmd: List of command going to be run by pinot-admin.sh script
:param verbose:
"""
- command = [self.cmd_path]
- command.extend(cmd)
-
+ command = [self.cmd_path, *cmd]
env = None
if self.pinot_admin_system_exit:
env = os.environ.copy()
@@ -214,13 +224,12 @@ def run_cli(self, cmd: List[str], verbose: bool = True) -> str:
if verbose:
self.log.info(" ".join(command))
-
with subprocess.Popen(
command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, close_fds=True, env=env
) as sub_process:
stdout = ""
if sub_process.stdout:
- for line in iter(sub_process.stdout.readline, b''):
+ for line in iter(sub_process.stdout.readline, b""):
stdout += line.decode("utf-8")
if verbose:
self.log.info(line.decode("utf-8").strip())
@@ -246,8 +255,8 @@ class PinotDbApiHook(DbApiHook):
https://docs.pinot.apache.org/users/api/querying-pinot-using-standard-sql
"""
- conn_name_attr = 'pinot_broker_conn_id'
- default_conn_name = 'pinot_broker_default'
+ conn_name_attr = "pinot_broker_conn_id"
+ default_conn_name = "pinot_broker_default"
supports_autocommit = False
def get_conn(self) -> Any:
@@ -257,10 +266,10 @@ def get_conn(self) -> Any:
pinot_broker_conn = connect(
host=conn.host,
port=conn.port,
- path=conn.extra_dejson.get('endpoint', '/query/sql'),
- scheme=conn.extra_dejson.get('schema', 'http'),
+ path=conn.extra_dejson.get("endpoint", "/query/sql"),
+ scheme=conn.extra_dejson.get("schema", "http"),
)
- self.log.info('Get the connection to pinot broker on %s', conn.host)
+ self.log.info("Get the connection to pinot broker on %s", conn.host)
return pinot_broker_conn
def get_uri(self) -> str:
@@ -272,12 +281,14 @@ def get_uri(self) -> str:
conn = self.get_connection(getattr(self, self.conn_name_attr))
host = conn.host
if conn.port is not None:
- host += f':{conn.port}'
- conn_type = 'http' if not conn.conn_type else conn.conn_type
- endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
- return f'{conn_type}://{host}/{endpoint}'
+ host += f":{conn.port}"
+ conn_type = conn.conn_type or "http"
+ endpoint = conn.extra_dejson.get("endpoint", "query/sql")
+ return f"{conn_type}://{host}/{endpoint}"
- def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+ def get_records(
+ self, sql: str | list[str], parameters: Iterable | Mapping | None = None, **kwargs
+ ) -> Any:
"""
Executes the sql and returns a set of records.
@@ -289,7 +300,7 @@ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Itera
cur.execute(sql)
return cur.fetchall()
- def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
@@ -308,7 +319,7 @@ def insert_rows(
self,
table: str,
rows: str,
- target_fields: Optional[str] = None,
+ target_fields: str | None = None,
commit_every: int = 1000,
replace: bool = False,
**kwargs: Any,
diff --git a/airflow/providers/apache/pinot/provider.yaml b/airflow/providers/apache/pinot/provider.yaml
index 000827c9e141c..e8b02fbd4105c 100644
--- a/airflow/providers/apache/pinot/provider.yaml
+++ b/airflow/providers/apache/pinot/provider.yaml
@@ -22,6 +22,11 @@ description: |
`Apache Pinot `__
versions:
+ - 4.0.0
+ - 3.2.1
+ - 3.2.0
+ - 3.1.0
+ - 3.0.0
- 2.0.4
- 2.0.3
- 2.0.2
@@ -30,8 +35,10 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - apache-airflow-providers-common-sql>=1.3.1
+ - pinotdb>0.4.7
integrations:
- integration-name: Apache Pinot
diff --git a/airflow/providers/apache/spark/.latest-doc-only-change.txt b/airflow/providers/apache/spark/.latest-doc-only-change.txt
index cda183acd3b04..ff7136e07d744 100644
--- a/airflow/providers/apache/spark/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/spark/.latest-doc-only-change.txt
@@ -1 +1 @@
-cb73053211367e2c2dd76d5279cdc7dc7b190124
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/spark/CHANGELOG.rst b/airflow/providers/apache/spark/CHANGELOG.rst
index f6e4cf48ba18d..7b0cca49af7e8 100644
--- a/airflow/providers/apache/spark/CHANGELOG.rst
+++ b/airflow/providers/apache/spark/CHANGELOG.rst
@@ -16,9 +16,74 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+4.0.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+The ``spark-binary`` connection extra could previously be set to any binary, but as of version 4.0.0 only
+two values are allowed for it: ``spark-submit`` and ``spark2-submit``.
+
+The ``spark-home`` connection extra is not allowed any more - the binary should be available on the
+PATH in order to use SparkSubmitHook and SparkSubmitOperator.
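
The two restrictions above could be enforced roughly as in the sketch below. The hook changes themselves are not shown in this part of the diff, so this is an assumption-laden illustration: the function name ``resolve_spark_binary``, the constant ``ALLOWED_SPARK_BINARIES``, the use of ``RuntimeError``, and the ``spark-submit`` default are all hypothetical.

```python
from typing import Any, Dict

# Hypothetical constant: the two values the changelog says are allowed.
ALLOWED_SPARK_BINARIES = ("spark-submit", "spark2-submit")


def resolve_spark_binary(extra: Dict[str, Any]) -> str:
    """Validate connection extras and return the binary name to run from PATH."""
    if "spark-home" in extra:
        # A custom Spark home is no longer accepted; the binary must be on PATH.
        raise RuntimeError("The 'spark-home' extra is not allowed; put the binary on PATH.")
    # Assumption for this sketch: fall back to spark-submit when nothing is set.
    spark_binary = extra.get("spark-binary", "spark-submit")
    if spark_binary not in ALLOWED_SPARK_BINARIES:
        raise RuntimeError(
            f"'spark-binary' must be one of {ALLOWED_SPARK_BINARIES}, got {spark_binary!r}."
        )
    return spark_binary


chosen = resolve_spark_binary({"spark-binary": "spark2-submit"})
print(chosen)
```

Restricting the extra to a fixed allow-list means a connection definition can no longer point the hook at an arbitrary executable.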
+
+* ``Remove custom spark home and custom binaries for spark (#27646)``
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Add typing for airflow/configuration.py (#23716)``
+* ``Fix backwards-compatibility introduced by fixing mypy problems (#24230)``
+
+Misc
+~~~~
+
+* ``AIP-47 - Migrate spark DAGs to new design #22439 (#24210)``
+* ``chore: Refactoring and Cleaning Apache Providers (#24219)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.1.3
.....
diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc.py b/airflow/providers/apache/spark/hooks/spark_jdbc.py
index df9d715be0dcc..87abfb863eae9 100644
--- a/airflow/providers/apache/spark/hooks/spark_jdbc.py
+++ b/airflow/providers/apache/spark/hooks/spark_jdbc.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import os
-from typing import Any, Dict, Optional
+from typing import Any
from airflow.exceptions import AirflowException
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
@@ -85,41 +86,41 @@ class SparkJDBCHook(SparkSubmitHook):
types.
"""
- conn_name_attr = 'spark_conn_id'
- default_conn_name = 'spark_default'
- conn_type = 'spark_jdbc'
- hook_name = 'Spark JDBC'
+ conn_name_attr = "spark_conn_id"
+ default_conn_name = "spark_default"
+ conn_type = "spark_jdbc"
+ hook_name = "Spark JDBC"
def __init__(
self,
- spark_app_name: str = 'airflow-spark-jdbc',
+ spark_app_name: str = "airflow-spark-jdbc",
spark_conn_id: str = default_conn_name,
- spark_conf: Optional[Dict[str, Any]] = None,
- spark_py_files: Optional[str] = None,
- spark_files: Optional[str] = None,
- spark_jars: Optional[str] = None,
- num_executors: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- driver_memory: Optional[str] = None,
+ spark_conf: dict[str, Any] | None = None,
+ spark_py_files: str | None = None,
+ spark_files: str | None = None,
+ spark_jars: str | None = None,
+ num_executors: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ driver_memory: str | None = None,
verbose: bool = False,
- principal: Optional[str] = None,
- keytab: Optional[str] = None,
- cmd_type: str = 'spark_to_jdbc',
- jdbc_table: Optional[str] = None,
- jdbc_conn_id: str = 'jdbc-default',
- jdbc_driver: Optional[str] = None,
- metastore_table: Optional[str] = None,
+ principal: str | None = None,
+ keytab: str | None = None,
+ cmd_type: str = "spark_to_jdbc",
+ jdbc_table: str | None = None,
+ jdbc_conn_id: str = "jdbc-default",
+ jdbc_driver: str | None = None,
+ metastore_table: str | None = None,
jdbc_truncate: bool = False,
- save_mode: Optional[str] = None,
- save_format: Optional[str] = None,
- batch_size: Optional[int] = None,
- fetch_size: Optional[int] = None,
- num_partitions: Optional[int] = None,
- partition_column: Optional[str] = None,
- lower_bound: Optional[str] = None,
- upper_bound: Optional[str] = None,
- create_table_column_types: Optional[str] = None,
+ save_mode: str | None = None,
+ save_format: str | None = None,
+ batch_size: int | None = None,
+ fetch_size: int | None = None,
+ num_partitions: int | None = None,
+ partition_column: str | None = None,
+ lower_bound: str | None = None,
+ upper_bound: str | None = None,
+ create_table_column_types: str | None = None,
*args: Any,
**kwargs: Any,
):
@@ -154,73 +155,73 @@ def __init__(
self._create_table_column_types = create_table_column_types
self._jdbc_connection = self._resolve_jdbc_connection()
- def _resolve_jdbc_connection(self) -> Dict[str, Any]:
- conn_data = {'url': '', 'schema': '', 'conn_prefix': '', 'user': '', 'password': ''}
+ def _resolve_jdbc_connection(self) -> dict[str, Any]:
+ conn_data = {"url": "", "schema": "", "conn_prefix": "", "user": "", "password": ""}
try:
conn = self.get_connection(self._jdbc_conn_id)
if conn.port:
- conn_data['url'] = f"{conn.host}:{conn.port}"
+ conn_data["url"] = f"{conn.host}:{conn.port}"
else:
- conn_data['url'] = conn.host
- conn_data['schema'] = conn.schema
- conn_data['user'] = conn.login
- conn_data['password'] = conn.password
+ conn_data["url"] = conn.host
+ conn_data["schema"] = conn.schema
+ conn_data["user"] = conn.login
+ conn_data["password"] = conn.password
extra = conn.extra_dejson
- conn_data['conn_prefix'] = extra.get('conn_prefix', '')
+ conn_data["conn_prefix"] = extra.get("conn_prefix", "")
except AirflowException:
self.log.debug(
"Could not load jdbc connection string %s, defaulting to %s", self._jdbc_conn_id, ""
)
return conn_data
- def _build_jdbc_application_arguments(self, jdbc_conn: Dict[str, Any]) -> Any:
+ def _build_jdbc_application_arguments(self, jdbc_conn: dict[str, Any]) -> Any:
arguments = []
arguments += ["-cmdType", self._cmd_type]
- if self._jdbc_connection['url']:
+ if self._jdbc_connection["url"]:
arguments += [
- '-url',
+ "-url",
f"{jdbc_conn['conn_prefix']}{jdbc_conn['url']}/{jdbc_conn['schema']}",
]
- if self._jdbc_connection['user']:
- arguments += ['-user', self._jdbc_connection['user']]
- if self._jdbc_connection['password']:
- arguments += ['-password', self._jdbc_connection['password']]
+ if self._jdbc_connection["user"]:
+ arguments += ["-user", self._jdbc_connection["user"]]
+ if self._jdbc_connection["password"]:
+ arguments += ["-password", self._jdbc_connection["password"]]
if self._metastore_table:
- arguments += ['-metastoreTable', self._metastore_table]
+ arguments += ["-metastoreTable", self._metastore_table]
if self._jdbc_table:
- arguments += ['-jdbcTable', self._jdbc_table]
+ arguments += ["-jdbcTable", self._jdbc_table]
if self._jdbc_truncate:
- arguments += ['-jdbcTruncate', str(self._jdbc_truncate)]
+ arguments += ["-jdbcTruncate", str(self._jdbc_truncate)]
if self._jdbc_driver:
- arguments += ['-jdbcDriver', self._jdbc_driver]
+ arguments += ["-jdbcDriver", self._jdbc_driver]
if self._batch_size:
- arguments += ['-batchsize', str(self._batch_size)]
+ arguments += ["-batchsize", str(self._batch_size)]
if self._fetch_size:
- arguments += ['-fetchsize', str(self._fetch_size)]
+ arguments += ["-fetchsize", str(self._fetch_size)]
if self._num_partitions:
- arguments += ['-numPartitions', str(self._num_partitions)]
+ arguments += ["-numPartitions", str(self._num_partitions)]
if self._partition_column and self._lower_bound and self._upper_bound and self._num_partitions:
# these 3 parameters need to be used all together to take effect.
arguments += [
- '-partitionColumn',
+ "-partitionColumn",
self._partition_column,
- '-lowerBound',
+ "-lowerBound",
self._lower_bound,
- '-upperBound',
+ "-upperBound",
self._upper_bound,
]
if self._save_mode:
- arguments += ['-saveMode', self._save_mode]
+ arguments += ["-saveMode", self._save_mode]
if self._save_format:
- arguments += ['-saveFormat', self._save_format]
+ arguments += ["-saveFormat", self._save_format]
if self._create_table_column_types:
- arguments += ['-createTableColumnTypes', self._create_table_column_types]
+ arguments += ["-createTableColumnTypes", self._create_table_column_types]
return arguments
def submit_jdbc_job(self) -> None:
"""Submit Spark JDBC job"""
self._application_args = self._build_jdbc_application_arguments(self._jdbc_connection)
- self.submit(application=os.path.dirname(os.path.abspath(__file__)) + "/spark_jdbc_script.py")
+ self.submit(application=f"{os.path.dirname(os.path.abspath(__file__))}/spark_jdbc_script.py")
def get_conn(self) -> Any:
pass
diff --git a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py
index c354de6ab0ca4..2cc10584d7752 100644
--- a/airflow/providers/apache/spark/hooks/spark_jdbc_script.py
+++ b/airflow/providers/apache/spark/hooks/spark_jdbc_script.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import argparse
-from typing import Any, List, Optional
+from typing import Any
from pyspark.sql import SparkSession
@@ -27,11 +28,11 @@
def set_common_options(
spark_source: Any,
- url: str = 'localhost:5432',
- jdbc_table: str = 'default.default',
- user: str = 'root',
- password: str = 'root',
- driver: str = 'driver',
+ url: str = "localhost:5432",
+ jdbc_table: str = "default.default",
+ user: str = "root",
+ password: str = "root",
+ driver: str = "driver",
) -> Any:
"""
Get Spark source from JDBC connection
@@ -44,12 +45,12 @@ def set_common_options(
:param driver: JDBC resource driver
"""
spark_source = (
- spark_source.format('jdbc')
- .option('url', url)
- .option('dbtable', jdbc_table)
- .option('user', user)
- .option('password', password)
- .option('driver', driver)
+ spark_source.format("jdbc")
+ .option("url", url)
+ .option("dbtable", jdbc_table)
+ .option("user", user)
+ .option("password", password)
+ .option("driver", driver)
)
return spark_source
@@ -75,11 +76,11 @@ def spark_write_to_jdbc(
# now set write-specific options
if truncate:
- writer = writer.option('truncate', truncate)
+ writer = writer.option("truncate", truncate)
if batch_size:
- writer = writer.option('batchsize', batch_size)
+ writer = writer.option("batchsize", batch_size)
if num_partitions:
- writer = writer.option('numPartitions', num_partitions)
+ writer = writer.option("numPartitions", num_partitions)
if create_table_column_types:
writer = writer.option("createTableColumnTypes", create_table_column_types)
@@ -108,39 +109,39 @@ def spark_read_from_jdbc(
# now set specific read options
if fetch_size:
- reader = reader.option('fetchsize', fetch_size)
+ reader = reader.option("fetchsize", fetch_size)
if num_partitions:
- reader = reader.option('numPartitions', num_partitions)
+ reader = reader.option("numPartitions", num_partitions)
if partition_column and lower_bound and upper_bound:
reader = (
- reader.option('partitionColumn', partition_column)
- .option('lowerBound', lower_bound)
- .option('upperBound', upper_bound)
+ reader.option("partitionColumn", partition_column)
+ .option("lowerBound", lower_bound)
+ .option("upperBound", upper_bound)
)
reader.load().write.saveAsTable(metastore_table, format=save_format, mode=save_mode)
-def _parse_arguments(args: Optional[List[str]] = None) -> Any:
- parser = argparse.ArgumentParser(description='Spark-JDBC')
- parser.add_argument('-cmdType', dest='cmd_type', action='store')
- parser.add_argument('-url', dest='url', action='store')
- parser.add_argument('-user', dest='user', action='store')
- parser.add_argument('-password', dest='password', action='store')
- parser.add_argument('-metastoreTable', dest='metastore_table', action='store')
- parser.add_argument('-jdbcTable', dest='jdbc_table', action='store')
- parser.add_argument('-jdbcDriver', dest='jdbc_driver', action='store')
- parser.add_argument('-jdbcTruncate', dest='truncate', action='store')
- parser.add_argument('-saveMode', dest='save_mode', action='store')
- parser.add_argument('-saveFormat', dest='save_format', action='store')
- parser.add_argument('-batchsize', dest='batch_size', action='store')
- parser.add_argument('-fetchsize', dest='fetch_size', action='store')
- parser.add_argument('-name', dest='name', action='store')
- parser.add_argument('-numPartitions', dest='num_partitions', action='store')
- parser.add_argument('-partitionColumn', dest='partition_column', action='store')
- parser.add_argument('-lowerBound', dest='lower_bound', action='store')
- parser.add_argument('-upperBound', dest='upper_bound', action='store')
- parser.add_argument('-createTableColumnTypes', dest='create_table_column_types', action='store')
+def _parse_arguments(args: list[str] | None = None) -> Any:
+ parser = argparse.ArgumentParser(description="Spark-JDBC")
+ parser.add_argument("-cmdType", dest="cmd_type", action="store")
+ parser.add_argument("-url", dest="url", action="store")
+ parser.add_argument("-user", dest="user", action="store")
+ parser.add_argument("-password", dest="password", action="store")
+ parser.add_argument("-metastoreTable", dest="metastore_table", action="store")
+ parser.add_argument("-jdbcTable", dest="jdbc_table", action="store")
+ parser.add_argument("-jdbcDriver", dest="jdbc_driver", action="store")
+ parser.add_argument("-jdbcTruncate", dest="truncate", action="store")
+ parser.add_argument("-saveMode", dest="save_mode", action="store")
+ parser.add_argument("-saveFormat", dest="save_format", action="store")
+ parser.add_argument("-batchsize", dest="batch_size", action="store")
+ parser.add_argument("-fetchsize", dest="fetch_size", action="store")
+ parser.add_argument("-name", dest="name", action="store")
+ parser.add_argument("-numPartitions", dest="num_partitions", action="store")
+ parser.add_argument("-partitionColumn", dest="partition_column", action="store")
+ parser.add_argument("-lowerBound", dest="lower_bound", action="store")
+ parser.add_argument("-upperBound", dest="upper_bound", action="store")
+ parser.add_argument("-createTableColumnTypes", dest="create_table_column_types", action="store")
return parser.parse_args(args=args)
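Reviewer note: `_parse_arguments` uses single-dash long options such as `-cmdType`, which `argparse` accepts as exact matches. The sketch below reproduces that style with a reduced option set so it can run without PySpark; `parse_jdbc_args` is a hypothetical name and the three options are only a subset of those in the script.

```python
import argparse


def parse_jdbc_args(argv):
    """Parse single-dash long options in the same style as the script above."""
    parser = argparse.ArgumentParser(description="Spark-JDBC (sketch)")
    parser.add_argument("-cmdType", dest="cmd_type", action="store")
    parser.add_argument("-url", dest="url", action="store")
    parser.add_argument("-numPartitions", dest="num_partitions", action="store")
    return parser.parse_args(args=argv)


args = parse_jdbc_args(["-cmdType", "spark_to_jdbc", "-url", "localhost:5432"])
# Options that were not passed default to None; passed values stay strings
# because no type= converter is configured.
print(args.cmd_type, args.url, args.num_partitions)
```

Since no `type=` is set, a value such as `-numPartitions 4` arrives as the string `"4"`, which is why the hook converts numeric settings with `str(...)` before passing them on.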
diff --git a/airflow/providers/apache/spark/hooks/spark_sql.py b/airflow/providers/apache/spark/hooks/spark_sql.py
index 2621bf9a4099b..d6f8f56c27732 100644
--- a/airflow/providers/apache/spark/hooks/spark_sql.py
+++ b/airflow/providers/apache/spark/hooks/spark_sql.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
import subprocess
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
@@ -49,30 +50,30 @@ class SparkSqlHook(BaseHook):
(Default: The ``queue`` value set in the Connection, or ``"default"``)
"""
- conn_name_attr = 'conn_id'
- default_conn_name = 'spark_sql_default'
- conn_type = 'spark_sql'
- hook_name = 'Spark SQL'
+ conn_name_attr = "conn_id"
+ default_conn_name = "spark_sql_default"
+ conn_type = "spark_sql"
+ hook_name = "Spark SQL"
def __init__(
self,
sql: str,
- conf: Optional[str] = None,
+ conf: str | None = None,
conn_id: str = default_conn_name,
- total_executor_cores: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- keytab: Optional[str] = None,
- principal: Optional[str] = None,
- master: Optional[str] = None,
- name: str = 'default-name',
- num_executors: Optional[int] = None,
+ total_executor_cores: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ keytab: str | None = None,
+ principal: str | None = None,
+ master: str | None = None,
+ name: str = "default-name",
+ num_executors: int | None = None,
verbose: bool = True,
- yarn_queue: Optional[str] = None,
+ yarn_queue: str | None = None,
) -> None:
super().__init__()
- options: Dict = {}
- conn: Optional[Connection] = None
+ options: dict = {}
+ conn: Connection | None = None
try:
conn = self.get_connection(conn_id)
@@ -109,7 +110,7 @@ def __init__(
def get_conn(self) -> Any:
pass
- def _prepare_command(self, cmd: Union[str, List[str]]) -> List[str]:
+ def _prepare_command(self, cmd: str | list[str]) -> list[str]:
"""
Construct the spark-sql command to execute. Verbose output is enabled
as default.
diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index 0f5dc2f7307cc..bfc08eda6402d 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -15,12 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
+from __future__ import annotations
+
+import contextlib
import os
import re
import subprocess
import time
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Iterator
from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException
@@ -28,17 +30,16 @@
from airflow.security.kerberos import renew_from_kt
from airflow.utils.log.logging_mixin import LoggingMixin
-try:
+with contextlib.suppress(ImportError, NameError):
from airflow.kubernetes import kube_client
-except (ImportError, NameError):
- pass
+
+ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit"]
class SparkSubmitHook(BaseHook, LoggingMixin):
"""
This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
- It requires that the "spark-submit" binary is in the PATH or the spark_home to be
- supplied.
+ It requires that the "spark-submit" binary is in the PATH.
:param conf: Arbitrary Spark configuration properties
:param spark_conn_id: The :ref:`spark connection id ` as configured
@@ -80,46 +81,46 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
Some distros may use spark2-submit.
"""
- conn_name_attr = 'conn_id'
- default_conn_name = 'spark_default'
- conn_type = 'spark'
- hook_name = 'Spark'
+ conn_name_attr = "conn_id"
+ default_conn_name = "spark_default"
+ conn_type = "spark"
+ hook_name = "Spark"
@staticmethod
- def get_ui_field_behaviour() -> Dict[str, Any]:
+ def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
- "hidden_fields": ['schema', 'login', 'password'],
+ "hidden_fields": ["schema", "login", "password"],
"relabeling": {},
}
def __init__(
self,
- conf: Optional[Dict[str, Any]] = None,
- conn_id: str = 'spark_default',
- files: Optional[str] = None,
- py_files: Optional[str] = None,
- archives: Optional[str] = None,
- driver_class_path: Optional[str] = None,
- jars: Optional[str] = None,
- java_class: Optional[str] = None,
- packages: Optional[str] = None,
- exclude_packages: Optional[str] = None,
- repositories: Optional[str] = None,
- total_executor_cores: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- driver_memory: Optional[str] = None,
- keytab: Optional[str] = None,
- principal: Optional[str] = None,
- proxy_user: Optional[str] = None,
- name: str = 'default-name',
- num_executors: Optional[int] = None,
+ conf: dict[str, Any] | None = None,
+ conn_id: str = "spark_default",
+ files: str | None = None,
+ py_files: str | None = None,
+ archives: str | None = None,
+ driver_class_path: str | None = None,
+ jars: str | None = None,
+ java_class: str | None = None,
+ packages: str | None = None,
+ exclude_packages: str | None = None,
+ repositories: str | None = None,
+ total_executor_cores: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ driver_memory: str | None = None,
+ keytab: str | None = None,
+ principal: str | None = None,
+ proxy_user: str | None = None,
+ name: str = "default-name",
+ num_executors: int | None = None,
status_poll_interval: int = 1,
- application_args: Optional[List[Any]] = None,
- env_vars: Optional[Dict[str, Any]] = None,
+ application_args: list[Any] | None = None,
+ env_vars: dict[str, Any] | None = None,
verbose: bool = False,
- spark_binary: Optional[str] = None,
+ spark_binary: str | None = None,
) -> None:
super().__init__()
self._conf = conf or {}
@@ -146,42 +147,47 @@ def __init__(
self._application_args = application_args
self._env_vars = env_vars
self._verbose = verbose
- self._submit_sp: Optional[Any] = None
- self._yarn_application_id: Optional[str] = None
- self._kubernetes_driver_pod: Optional[str] = None
+ self._submit_sp: Any | None = None
+ self._yarn_application_id: str | None = None
+ self._kubernetes_driver_pod: str | None = None
self._spark_binary = spark_binary
+ if self._spark_binary is not None and self._spark_binary not in ALLOWED_SPARK_BINARIES:
+ raise RuntimeError(
+ f"The spark-binary extra can be one of {ALLOWED_SPARK_BINARIES} and it"
+ f" was `{spark_binary}`. Please make sure your spark binary is one of the"
+ f" allowed ones and that it is available on the PATH"
+ )
self._connection = self._resolve_connection()
- self._is_yarn = 'yarn' in self._connection['master']
- self._is_kubernetes = 'k8s' in self._connection['master']
+ self._is_yarn = "yarn" in self._connection["master"]
+ self._is_kubernetes = "k8s" in self._connection["master"]
if self._is_kubernetes and kube_client is None:
raise RuntimeError(
f"{self._connection['master']} specified by kubernetes dependencies are not installed!"
)
self._should_track_driver_status = self._resolve_should_track_driver_status()
- self._driver_id: Optional[str] = None
- self._driver_status: Optional[str] = None
- self._spark_exit_code: Optional[int] = None
- self._env: Optional[Dict[str, Any]] = None
+ self._driver_id: str | None = None
+ self._driver_status: str | None = None
+ self._spark_exit_code: int | None = None
+ self._env: dict[str, Any] | None = None
def _resolve_should_track_driver_status(self) -> bool:
"""
- Determines whether or not this hook should poll the spark driver status through
+ Determines whether this hook should poll the spark driver status through
subsequent spark-submit status requests after the initial spark-submit request
:return: if the driver status should be tracked
"""
- return 'spark://' in self._connection['master'] and self._connection['deploy_mode'] == 'cluster'
+ return "spark://" in self._connection["master"] and self._connection["deploy_mode"] == "cluster"
- def _resolve_connection(self) -> Dict[str, Any]:
+ def _resolve_connection(self) -> dict[str, Any]:
# Build from connection master or default to yarn if not available
conn_data = {
- 'master': 'yarn',
- 'queue': None,
- 'deploy_mode': None,
- 'spark_home': None,
- 'spark_binary': self._spark_binary or "spark-submit",
- 'namespace': None,
+ "master": "yarn",
+ "queue": None,
+ "deploy_mode": None,
+ "spark_binary": self._spark_binary or "spark-submit",
+ "namespace": None,
}
try:
@@ -189,44 +195,47 @@ def _resolve_connection(self) -> Dict[str, Any]:
# k8s://https://:
conn = self.get_connection(self._conn_id)
if conn.port:
- conn_data['master'] = f"{conn.host}:{conn.port}"
+ conn_data["master"] = f"{conn.host}:{conn.port}"
else:
- conn_data['master'] = conn.host
+ conn_data["master"] = conn.host
# Determine optional yarn queue from the extra field
extra = conn.extra_dejson
- conn_data['queue'] = extra.get('queue')
- conn_data['deploy_mode'] = extra.get('deploy-mode')
- conn_data['spark_home'] = extra.get('spark-home')
- conn_data['spark_binary'] = self._spark_binary or extra.get('spark-binary', "spark-submit")
- conn_data['namespace'] = extra.get('namespace')
+ conn_data["queue"] = extra.get("queue")
+ conn_data["deploy_mode"] = extra.get("deploy-mode")
+ spark_binary = self._spark_binary or extra.get("spark-binary", "spark-submit")
+ if spark_binary not in ALLOWED_SPARK_BINARIES:
+ raise RuntimeError(
+ f"The `spark-binary` extra can be one of {ALLOWED_SPARK_BINARIES} and it"
+ f" was `{spark_binary}`. Please make sure your spark binary is one of the"
+ " allowed ones and that it is available on the PATH"
+ )
+ conn_spark_home = extra.get("spark-home")
+ if conn_spark_home:
+ raise RuntimeError(
+ "The `spark-home` extra is not allowed any more. Please make sure your `spark-submit` or"
+ " `spark2-submit` are available on the PATH."
+ )
+ conn_data["spark_binary"] = spark_binary
+ conn_data["namespace"] = extra.get("namespace")
except AirflowException:
self.log.info(
- "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data['master']
+ "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data["master"]
)
- if 'spark.kubernetes.namespace' in self._conf:
- conn_data['namespace'] = self._conf['spark.kubernetes.namespace']
+ if "spark.kubernetes.namespace" in self._conf:
+ conn_data["namespace"] = self._conf["spark.kubernetes.namespace"]
return conn_data
def get_conn(self) -> Any:
pass
- def _get_spark_binary_path(self) -> List[str]:
- # If the spark_home is passed then build the spark-submit executable path using
- # the spark_home; otherwise assume that spark-submit is present in the path to
- # the executing user
- if self._connection['spark_home']:
- connection_cmd = [
- os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])
- ]
- else:
- connection_cmd = [self._connection['spark_binary']]
-
- return connection_cmd
+ def _get_spark_binary_path(self) -> list[str]:
+ # Assume that the spark-submit binary is available on the PATH of the executing user
+ return [self._connection["spark_binary"]]
- def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str:
+ def _mask_cmd(self, connection_cmd: str | list[str]) -> str:
# Mask any password related fields in application args with key value pair
# where key contains password (case insensitive), e.g. HivePassword='abc'
connection_cmd_masked = re.sub(
@@ -243,14 +252,14 @@ def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str:
# (matched above); if the value is quoted,
# it may contain whitespace.
r"(\2)", # Optional matching quote.
- r'\1******\3',
- ' '.join(connection_cmd),
+ r"\1******\3",
+ " ".join(connection_cmd),
flags=re.I,
)
return connection_cmd_masked
- def _build_spark_submit_command(self, application: str) -> List[str]:
+ def _build_spark_submit_command(self, application: str) -> list[str]:
"""
Construct the spark-submit command to execute.
@@ -260,7 +269,7 @@ def _build_spark_submit_command(self, application: str) -> List[str]:
connection_cmd = self._get_spark_binary_path()
# The url of the spark master
- connection_cmd += ["--master", self._connection['master']]
+ connection_cmd += ["--master", self._connection["master"]]
for key in self._conf:
connection_cmd += ["--conf", f"{key}={str(self._conf[key])}"]
@@ -273,11 +282,11 @@ def _build_spark_submit_command(self, application: str) -> List[str]:
tmpl = "spark.kubernetes.driverEnv.{}={}"
for key in self._env_vars:
connection_cmd += ["--conf", tmpl.format(key, str(self._env_vars[key]))]
- elif self._env_vars and self._connection['deploy_mode'] != "cluster":
+ elif self._env_vars and self._connection["deploy_mode"] != "cluster":
self._env = self._env_vars # Do it on Popen of the process
- elif self._env_vars and self._connection['deploy_mode'] == "cluster":
+ elif self._env_vars and self._connection["deploy_mode"] == "cluster":
raise AirflowException("SparkSubmitHook env_vars is not supported in standalone-cluster mode.")
- if self._is_kubernetes and self._connection['namespace']:
+ if self._is_kubernetes and self._connection["namespace"]:
connection_cmd += [
"--conf",
f"spark.kubernetes.namespace={self._connection['namespace']}",
@@ -320,10 +329,10 @@ def _build_spark_submit_command(self, application: str) -> List[str]:
connection_cmd += ["--class", self._java_class]
if self._verbose:
connection_cmd += ["--verbose"]
- if self._connection['queue']:
- connection_cmd += ["--queue", self._connection['queue']]
- if self._connection['deploy_mode']:
- connection_cmd += ["--deploy-mode", self._connection['deploy_mode']]
+ if self._connection["queue"]:
+ connection_cmd += ["--queue", self._connection["queue"]]
+ if self._connection["deploy_mode"]:
+ connection_cmd += ["--deploy-mode", self._connection["deploy_mode"]]
# The actual script to execute
connection_cmd += [application]
@@ -336,15 +345,15 @@ def _build_spark_submit_command(self, application: str) -> List[str]:
return connection_cmd
- def _build_track_driver_status_command(self) -> List[str]:
+ def _build_track_driver_status_command(self) -> list[str]:
"""
Construct the command to poll the driver status.
:return: full command to be executed
"""
curl_max_wait_time = 30
- spark_host = self._connection['master']
- if spark_host.endswith(':6066'):
+ spark_host = self._connection["master"]
+ if spark_host.endswith(":6066"):
spark_host = spark_host.replace("spark://", "http://")
connection_cmd = [
"/usr/bin/curl",
@@ -355,9 +364,7 @@ def _build_track_driver_status_command(self) -> List[str]:
self.log.info(connection_cmd)
# The driver id so we can poll for its status
- if self._driver_id:
- pass
- else:
+ if not self._driver_id:
raise AirflowException(
"Invalid status: attempted to poll driver status but no driver id is known. Giving up."
)
@@ -367,7 +374,7 @@ def _build_track_driver_status_command(self) -> List[str]:
connection_cmd = self._get_spark_binary_path()
# The url to the spark master
- connection_cmd += ["--master", self._connection['master']]
+ connection_cmd += ["--master", self._connection["master"]]
# The driver id so we can poll for its status
if self._driver_id:
@@ -457,8 +464,8 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
line = line.strip()
# If we run yarn cluster mode, we want to extract the application id from
# the logs so we can kill the application when we stop it unexpectedly
- if self._is_yarn and self._connection['deploy_mode'] == 'cluster':
- match = re.search('(application[0-9_]+)', line)
+ if self._is_yarn and self._connection["deploy_mode"] == "cluster":
+ match = re.search("(application[0-9_]+)", line)
if match:
self._yarn_application_id = match.groups()[0]
self.log.info("Identified spark driver id: %s", self._yarn_application_id)
@@ -466,13 +473,13 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
# If we run Kubernetes cluster mode, we want to extract the driver pod id
# from the logs so we can kill the application when we stop it unexpectedly
elif self._is_kubernetes:
- match = re.search(r'\s*pod name: ((.+?)-([a-z0-9]+)-driver)', line)
+ match = re.search(r"\s*pod name: ((.+?)-([a-z0-9]+)-driver)", line)
if match:
self._kubernetes_driver_pod = match.groups()[0]
self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod)
# Store the Spark Exit code
- match_exit_code = re.search(r'\s*[eE]xit code: (\d+)', line)
+ match_exit_code = re.search(r"\s*[eE]xit code: (\d+)", line)
if match_exit_code:
self._spark_exit_code = int(match_exit_code.groups()[0])
@@ -480,7 +487,7 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) -> None:
# we need to extract the driver id from the logs. This allows us to poll for
# the status using the driver id. Also, we can kill the driver when needed.
elif self._should_track_driver_status and not self._driver_id:
- match_driver_id = re.search(r'(driver-[0-9\-]+)', line)
+ match_driver_id = re.search(r"(driver-[0-9\-]+)", line)
if match_driver_id:
self._driver_id = match_driver_id.groups()[0]
self.log.info("identified spark driver id: %s", self._driver_id)
@@ -505,7 +512,7 @@ def _process_spark_status_log(self, itr: Iterator[Any]) -> None:
# Check if the log line is about the driver status and extract the status.
if "driverState" in line:
- self._driver_status = line.split(' : ')[1].replace(',', '').replace('\"', '').strip()
+ self._driver_status = line.split(" : ")[1].replace(",", "").replace('"', "").strip()
driver_found = True
self.log.debug("spark driver status log: %s", line)
@@ -577,23 +584,16 @@ def _start_driver_status_tracking(self) -> None:
f"returncode = {returncode}"
)
- def _build_spark_driver_kill_command(self) -> List[str]:
+ def _build_spark_driver_kill_command(self) -> list[str]:
"""
Construct the spark-submit command to kill a driver.
:return: full command to kill a driver
"""
- # If the spark_home is passed then build the spark-submit executable path using
- # the spark_home; otherwise assume that spark-submit is present in the path to
- # the executing user
- if self._connection['spark_home']:
- connection_cmd = [
- os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])
- ]
- else:
- connection_cmd = [self._connection['spark_binary']]
+ # Assume that the spark-submit binary is available on the PATH of the executing user
+ connection_cmd = [self._connection["spark_binary"]]
# The url to the spark master
- connection_cmd += ["--master", self._connection['master']]
+ connection_cmd += ["--master", self._connection["master"]]
# The actual kill command
if self._driver_id:
@@ -607,20 +607,17 @@ def on_kill(self) -> None:
"""Kill Spark submit command"""
self.log.debug("Kill Command is being called")
- if self._should_track_driver_status:
- if self._driver_id:
- self.log.info('Killing driver %s on cluster', self._driver_id)
+ if self._should_track_driver_status and self._driver_id:
+ self.log.info("Killing driver %s on cluster", self._driver_id)
- kill_cmd = self._build_spark_driver_kill_command()
- with subprocess.Popen(
- kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
- ) as driver_kill:
- self.log.info(
- "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
- )
+ kill_cmd = self._build_spark_driver_kill_command()
+ with subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as driver_kill:
+ self.log.info(
+ "Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait()
+ )
if self._submit_sp and self._submit_sp.poll() is None:
- self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
+ self.log.info("Sending kill signal to %s", self._connection["spark_binary"])
self._submit_sp.kill()
if self._yarn_application_id:
@@ -632,7 +629,8 @@ def on_kill(self) -> None:
# we still attempt to kill the yarn application
renew_from_kt(self._principal, self._keytab, exit_on_fail=False)
env = os.environ.copy()
- env["KRB5CCNAME"] = airflow_conf.get_mandatory_value('kerberos', 'ccache')
+ ccache = airflow_conf.get_mandatory_value("kerberos", "ccache")
+ env["KRB5CCNAME"] = ccache
with subprocess.Popen(
kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE
@@ -640,7 +638,7 @@ def on_kill(self) -> None:
self.log.info("YARN app killed with return code: %s", yarn_kill.wait())
if self._kubernetes_driver_pod:
- self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod)
+ self.log.info("Killing pod %s on Kubernetes", self._kubernetes_driver_pod)
# Currently only instantiate Kubernetes client for killing a spark pod.
try:
@@ -649,7 +647,7 @@ def on_kill(self) -> None:
client = kube_client.get_kube_client()
api_response = client.delete_namespaced_pod(
self._kubernetes_driver_pod,
- self._connection['namespace'],
+ self._connection["namespace"],
body=kubernetes.client.V1DeleteOptions(),
pretty=True,
)
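Reviewer note: the hook changes above restrict the binary to an allow-list and reject the `spark-home` extra. The following is a minimal standalone sketch of that validation order, assuming the connection extras arrive as a plain dict; `resolve_spark_binary` is a hypothetical helper and the error messages are shortened, so this is not the hook's actual method.

```python
ALLOWED_SPARK_BINARIES = ["spark-submit", "spark2-submit"]


def resolve_spark_binary(hook_value, extra):
    """Hypothetical helper: choose the binary name, then validate it."""
    # An explicit hook argument wins over the connection extra.
    spark_binary = hook_value or extra.get("spark-binary", "spark-submit")
    if spark_binary not in ALLOWED_SPARK_BINARIES:
        raise RuntimeError(
            f"spark-binary must be one of {ALLOWED_SPARK_BINARIES}, got `{spark_binary}`"
        )
    if extra.get("spark-home"):
        raise RuntimeError("The `spark-home` extra is not allowed any more.")
    return spark_binary


print(resolve_spark_binary(None, {}))
print(resolve_spark_binary(None, {"spark-binary": "spark2-submit"}))
```

Checking membership in a fixed list means only a bare binary name resolved through the PATH can be executed, which appears to be the intent of removing the `spark_home` path building in this patch.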
diff --git a/airflow/providers/apache/spark/operators/spark_jdbc.py b/airflow/providers/apache/spark/operators/spark_jdbc.py
index 87f244be50899..7dd035db40fa7 100644
--- a/airflow/providers/apache/spark/operators/spark_jdbc.py
+++ b/airflow/providers/apache/spark/operators/spark_jdbc.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
from airflow.providers.apache.spark.hooks.spark_jdbc import SparkJDBCHook
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator
@@ -96,34 +97,34 @@ class SparkJDBCOperator(SparkSubmitOperator):
def __init__(
self,
*,
- spark_app_name: str = 'airflow-spark-jdbc',
- spark_conn_id: str = 'spark-default',
- spark_conf: Optional[Dict[str, Any]] = None,
- spark_py_files: Optional[str] = None,
- spark_files: Optional[str] = None,
- spark_jars: Optional[str] = None,
- num_executors: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- driver_memory: Optional[str] = None,
+ spark_app_name: str = "airflow-spark-jdbc",
+ spark_conn_id: str = "spark-default",
+ spark_conf: dict[str, Any] | None = None,
+ spark_py_files: str | None = None,
+ spark_files: str | None = None,
+ spark_jars: str | None = None,
+ num_executors: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ driver_memory: str | None = None,
verbose: bool = False,
- principal: Optional[str] = None,
- keytab: Optional[str] = None,
- cmd_type: str = 'spark_to_jdbc',
- jdbc_table: Optional[str] = None,
- jdbc_conn_id: str = 'jdbc-default',
- jdbc_driver: Optional[str] = None,
- metastore_table: Optional[str] = None,
+ principal: str | None = None,
+ keytab: str | None = None,
+ cmd_type: str = "spark_to_jdbc",
+ jdbc_table: str | None = None,
+ jdbc_conn_id: str = "jdbc-default",
+ jdbc_driver: str | None = None,
+ metastore_table: str | None = None,
jdbc_truncate: bool = False,
- save_mode: Optional[str] = None,
- save_format: Optional[str] = None,
- batch_size: Optional[int] = None,
- fetch_size: Optional[int] = None,
- num_partitions: Optional[int] = None,
- partition_column: Optional[str] = None,
- lower_bound: Optional[str] = None,
- upper_bound: Optional[str] = None,
- create_table_column_types: Optional[str] = None,
+ save_mode: str | None = None,
+ save_format: str | None = None,
+ batch_size: int | None = None,
+ fetch_size: int | None = None,
+ num_partitions: int | None = None,
+ partition_column: str | None = None,
+ lower_bound: str | None = None,
+ upper_bound: str | None = None,
+ create_table_column_types: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -155,9 +156,9 @@ def __init__(
self._lower_bound = lower_bound
self._upper_bound = upper_bound
self._create_table_column_types = create_table_column_types
- self._hook: Optional[SparkJDBCHook] = None
+ self._hook: SparkJDBCHook | None = None
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job"""
if self._hook is None:
self._hook = self._get_hook()
diff --git a/airflow/providers/apache/spark/operators/spark_sql.py b/airflow/providers/apache/spark/operators/spark_sql.py
index 33f19e9c43287..a5150d281aebe 100644
--- a/airflow/providers/apache/spark/operators/spark_sql.py
+++ b/airflow/providers/apache/spark/operators/spark_sql.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-from typing import TYPE_CHECKING, Any, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
from airflow.providers.apache.spark.hooks.spark_sql import SparkSqlHook
@@ -51,26 +52,26 @@ class SparkSqlOperator(BaseOperator):
(Default: The ``queue`` value set in the Connection, or ``"default"``)
"""
- template_fields: Sequence[str] = ('_sql',)
+ template_fields: Sequence[str] = ("_sql",)
template_ext: Sequence[str] = (".sql", ".hql")
- template_fields_renderers = {'_sql': 'sql'}
+ template_fields_renderers = {"_sql": "sql"}
def __init__(
self,
*,
sql: str,
- conf: Optional[str] = None,
- conn_id: str = 'spark_sql_default',
- total_executor_cores: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- keytab: Optional[str] = None,
- principal: Optional[str] = None,
- master: Optional[str] = None,
- name: str = 'default-name',
- num_executors: Optional[int] = None,
+ conf: str | None = None,
+ conn_id: str = "spark_sql_default",
+ total_executor_cores: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ keytab: str | None = None,
+ principal: str | None = None,
+ master: str | None = None,
+ name: str = "default-name",
+ num_executors: int | None = None,
verbose: bool = True,
- yarn_queue: Optional[str] = None,
+ yarn_queue: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -87,9 +88,9 @@ def __init__(
self._num_executors = num_executors
self._verbose = verbose
self._yarn_queue = yarn_queue
- self._hook: Optional[SparkSqlHook] = None
+ self._hook: SparkSqlHook | None = None
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
"""Call the SparkSqlHook to run the provided sql query"""
if self._hook is None:
self._hook = self._get_hook()
diff --git a/airflow/providers/apache/spark/operators/spark_submit.py b/airflow/providers/apache/spark/operators/spark_submit.py
index db1114cf201dc..b0b2961c73f4c 100644
--- a/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/airflow/providers/apache/spark/operators/spark_submit.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.models import BaseOperator
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
@@ -29,8 +30,7 @@
class SparkSubmitOperator(BaseOperator):
"""
This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
- It requires that the "spark-submit" binary is in the PATH or the spark-home is set
- in the extra on the connection.
+ It requires that the "spark-submit" binary is in the PATH.
.. seealso::
For more information on how to use this operator, take a look at the guide:
@@ -73,52 +73,52 @@ class SparkSubmitOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- '_application',
- '_conf',
- '_files',
- '_py_files',
- '_jars',
- '_driver_class_path',
- '_packages',
- '_exclude_packages',
- '_keytab',
- '_principal',
- '_proxy_user',
- '_name',
- '_application_args',
- '_env_vars',
+ "_application",
+ "_conf",
+ "_files",
+ "_py_files",
+ "_jars",
+ "_driver_class_path",
+ "_packages",
+ "_exclude_packages",
+ "_keytab",
+ "_principal",
+ "_proxy_user",
+ "_name",
+ "_application_args",
+ "_env_vars",
)
- ui_color = WEB_COLORS['LIGHTORANGE']
+ ui_color = WEB_COLORS["LIGHTORANGE"]
def __init__(
self,
*,
- application: str = '',
- conf: Optional[Dict[str, Any]] = None,
- conn_id: str = 'spark_default',
- files: Optional[str] = None,
- py_files: Optional[str] = None,
- archives: Optional[str] = None,
- driver_class_path: Optional[str] = None,
- jars: Optional[str] = None,
- java_class: Optional[str] = None,
- packages: Optional[str] = None,
- exclude_packages: Optional[str] = None,
- repositories: Optional[str] = None,
- total_executor_cores: Optional[int] = None,
- executor_cores: Optional[int] = None,
- executor_memory: Optional[str] = None,
- driver_memory: Optional[str] = None,
- keytab: Optional[str] = None,
- principal: Optional[str] = None,
- proxy_user: Optional[str] = None,
- name: str = 'arrow-spark',
- num_executors: Optional[int] = None,
+ application: str = "",
+ conf: dict[str, Any] | None = None,
+ conn_id: str = "spark_default",
+ files: str | None = None,
+ py_files: str | None = None,
+ archives: str | None = None,
+ driver_class_path: str | None = None,
+ jars: str | None = None,
+ java_class: str | None = None,
+ packages: str | None = None,
+ exclude_packages: str | None = None,
+ repositories: str | None = None,
+ total_executor_cores: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ driver_memory: str | None = None,
+ keytab: str | None = None,
+ principal: str | None = None,
+ proxy_user: str | None = None,
+ name: str = "arrow-spark",
+ num_executors: int | None = None,
status_poll_interval: int = 1,
- application_args: Optional[List[Any]] = None,
- env_vars: Optional[Dict[str, Any]] = None,
+ application_args: list[Any] | None = None,
+ env_vars: dict[str, Any] | None = None,
verbose: bool = False,
- spark_binary: Optional[str] = None,
+ spark_binary: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -147,10 +147,10 @@ def __init__(
self._env_vars = env_vars
self._verbose = verbose
self._spark_binary = spark_binary
- self._hook: Optional[SparkSubmitHook] = None
+ self._hook: SparkSubmitHook | None = None
self._conn_id = conn_id
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job"""
if self._hook is None:
self._hook = self._get_hook()
diff --git a/airflow/providers/apache/spark/provider.yaml b/airflow/providers/apache/spark/provider.yaml
index c0a2dd23c21c8..a40fa651f3084 100644
--- a/airflow/providers/apache/spark/provider.yaml
+++ b/airflow/providers/apache/spark/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Apache Spark `__
versions:
+ - 4.0.0
+ - 3.0.0
- 2.1.3
- 2.1.2
- 2.1.1
@@ -35,8 +37,9 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - pyspark
integrations:
- integration-name: Apache Spark
@@ -61,10 +64,6 @@ hooks:
- airflow.providers.apache.spark.hooks.spark_sql
- airflow.providers.apache.spark.hooks.spark_submit
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook
- - airflow.providers.apache.spark.hooks.spark_sql.SparkSqlHook
- - airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook
connection-types:
- hook-class-name: airflow.providers.apache.spark.hooks.spark_jdbc.SparkJDBCHook
diff --git a/airflow/providers/apache/sqoop/.latest-doc-only-change.txt b/airflow/providers/apache/sqoop/.latest-doc-only-change.txt
index e7c3c940c9c77..ff7136e07d744 100644
--- a/airflow/providers/apache/sqoop/.latest-doc-only-change.txt
+++ b/airflow/providers/apache/sqoop/.latest-doc-only-change.txt
@@ -1 +1 @@
-602abe8394fafe7de54df7e73af56de848cdf617
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/apache/sqoop/CHANGELOG.rst b/airflow/providers/apache/sqoop/CHANGELOG.rst
index 5178f1d269566..a5785dd689508 100644
--- a/airflow/providers/apache/sqoop/CHANGELOG.rst
+++ b/airflow/providers/apache/sqoop/CHANGELOG.rst
@@ -16,9 +16,50 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.1.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.1.3
.....
diff --git a/airflow/providers/apache/sqoop/hooks/sqoop.py b/airflow/providers/apache/sqoop/hooks/sqoop.py
index 65ed7500cb5ae..43a1c105885b6 100644
--- a/airflow/providers/apache/sqoop/hooks/sqoop.py
+++ b/airflow/providers/apache/sqoop/hooks/sqoop.py
@@ -15,12 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
-
"""This module contains a sqoop 1.x hook"""
+from __future__ import annotations
+
import subprocess
from copy import deepcopy
-from typing import Any, Dict, List, Optional
+from typing import Any
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
@@ -48,30 +48,30 @@ class SqoopHook(BaseHook):
:param properties: Properties to set via the -D argument
"""
- conn_name_attr = 'conn_id'
- default_conn_name = 'sqoop_default'
- conn_type = 'sqoop'
- hook_name = 'Sqoop'
+ conn_name_attr = "conn_id"
+ default_conn_name = "sqoop_default"
+ conn_type = "sqoop"
+ hook_name = "Sqoop"
def __init__(
self,
conn_id: str = default_conn_name,
verbose: bool = False,
- num_mappers: Optional[int] = None,
- hcatalog_database: Optional[str] = None,
- hcatalog_table: Optional[str] = None,
- properties: Optional[Dict[str, Any]] = None,
+ num_mappers: int | None = None,
+ hcatalog_database: str | None = None,
+ hcatalog_table: str | None = None,
+ properties: dict[str, Any] | None = None,
) -> None:
# No mutable types in the default parameters
super().__init__()
self.conn = self.get_connection(conn_id)
connection_parameters = self.conn.extra_dejson
- self.job_tracker = connection_parameters.get('job_tracker', None)
- self.namenode = connection_parameters.get('namenode', None)
- self.libjars = connection_parameters.get('libjars', None)
- self.files = connection_parameters.get('files', None)
- self.archives = connection_parameters.get('archives', None)
- self.password_file = connection_parameters.get('password_file', None)
+ self.job_tracker = connection_parameters.get("job_tracker", None)
+ self.namenode = connection_parameters.get("namenode", None)
+ self.libjars = connection_parameters.get("libjars", None)
+ self.files = connection_parameters.get("files", None)
+ self.archives = connection_parameters.get("archives", None)
+ self.password_file = connection_parameters.get("password_file", None)
self.hcatalog_database = hcatalog_database
self.hcatalog_table = hcatalog_table
self.verbose = verbose
@@ -83,17 +83,17 @@ def __init__(
def get_conn(self) -> Any:
return self.conn
- def cmd_mask_password(self, cmd_orig: List[str]) -> List[str]:
+ def cmd_mask_password(self, cmd_orig: list[str]) -> list[str]:
"""Mask command password for safety"""
cmd = deepcopy(cmd_orig)
try:
- password_index = cmd.index('--password')
- cmd[password_index + 1] = 'MASKED'
+ password_index = cmd.index("--password")
+ cmd[password_index + 1] = "MASKED"
except ValueError:
self.log.debug("No password in sqoop cmd")
return cmd
- def popen(self, cmd: List[str], **kwargs: Any) -> None:
+ def popen(self, cmd: list[str], **kwargs: Any) -> None:
"""
Remote Popen
@@ -101,7 +101,7 @@ def popen(self, cmd: List[str], **kwargs: Any) -> None:
:param kwargs: extra arguments to Popen (see subprocess.Popen)
:return: handle to subprocess
"""
- masked_cmd = ' '.join(self.cmd_mask_password(cmd))
+ masked_cmd = " ".join(self.cmd_mask_password(cmd))
self.log.info("Executing command: %s", masked_cmd)
with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, **kwargs) as sub_process:
self.sub_process_pid = sub_process.pid
@@ -112,7 +112,7 @@ def popen(self, cmd: List[str], **kwargs: Any) -> None:
if sub_process.returncode:
raise AirflowException(f"Sqoop command failed: {masked_cmd}")
- def _prepare_command(self, export: bool = False) -> List[str]:
+ def _prepare_command(self, export: bool = False) -> list[str]:
sqoop_cmd_type = "export" if export else "import"
connection_cmd = ["sqoop", sqoop_cmd_type]
@@ -149,7 +149,7 @@ def _prepare_command(self, export: bool = False) -> List[str]:
connect_str += f":{self.conn.port}"
if self.conn.schema:
self.log.info("CONNECTION TYPE %s", self.conn.conn_type)
- if self.conn.conn_type != 'mssql':
+ if self.conn.conn_type != "mssql":
connect_str += f"/{self.conn.schema}"
else:
connect_str += f";databaseName={self.conn.schema}"
@@ -158,7 +158,7 @@ def _prepare_command(self, export: bool = False) -> List[str]:
return connection_cmd
@staticmethod
- def _get_export_format_argument(file_type: str = 'text') -> List[str]:
+ def _get_export_format_argument(file_type: str = "text") -> list[str]:
if file_type == "avro":
return ["--as-avrodatafile"]
elif file_type == "sequence":
@@ -172,14 +172,14 @@ def _get_export_format_argument(file_type: str = 'text') -> List[str]:
def _import_cmd(
self,
- target_dir: Optional[str],
+ target_dir: str | None,
append: bool,
file_type: str,
- split_by: Optional[str],
- direct: Optional[bool],
+ split_by: str | None,
+ direct: bool | None,
driver: Any,
extra_import_options: Any,
- ) -> List[str]:
+ ) -> list[str]:
cmd = self._prepare_command(export=False)
@@ -202,7 +202,7 @@ def _import_cmd(
if extra_import_options:
for key, value in extra_import_options.items():
- cmd += [f'--{key}']
+ cmd += [f"--{key}"]
if value:
cmd += [str(value)]
@@ -211,16 +211,16 @@ def _import_cmd(
def import_table(
self,
table: str,
- target_dir: Optional[str] = None,
+ target_dir: str | None = None,
append: bool = False,
file_type: str = "text",
- columns: Optional[str] = None,
- split_by: Optional[str] = None,
- where: Optional[str] = None,
+ columns: str | None = None,
+ split_by: str | None = None,
+ where: str | None = None,
direct: bool = False,
driver: Any = None,
- extra_import_options: Optional[Dict[str, Any]] = None,
- schema: Optional[str] = None,
+ extra_import_options: dict[str, Any] | None = None,
+ schema: str | None = None,
) -> Any:
"""
Imports table from remote location to target dir. Arguments are
@@ -257,13 +257,13 @@ def import_table(
def import_query(
self,
query: str,
- target_dir: Optional[str] = None,
+ target_dir: str | None = None,
append: bool = False,
file_type: str = "text",
- split_by: Optional[str] = None,
- direct: Optional[bool] = None,
- driver: Optional[Any] = None,
- extra_import_options: Optional[Dict[str, Any]] = None,
+ split_by: str | None = None,
+ direct: bool | None = None,
+ driver: Any | None = None,
+ extra_import_options: dict[str, Any] | None = None,
) -> Any:
"""
Imports a specific query from the rdbms to hdfs
@@ -288,21 +288,21 @@ def import_query(
def _export_cmd(
self,
table: str,
- export_dir: Optional[str] = None,
- input_null_string: Optional[str] = None,
- input_null_non_string: Optional[str] = None,
- staging_table: Optional[str] = None,
+ export_dir: str | None = None,
+ input_null_string: str | None = None,
+ input_null_non_string: str | None = None,
+ staging_table: str | None = None,
clear_staging_table: bool = False,
- enclosed_by: Optional[str] = None,
- escaped_by: Optional[str] = None,
- input_fields_terminated_by: Optional[str] = None,
- input_lines_terminated_by: Optional[str] = None,
- input_optionally_enclosed_by: Optional[str] = None,
+ enclosed_by: str | None = None,
+ escaped_by: str | None = None,
+ input_fields_terminated_by: str | None = None,
+ input_lines_terminated_by: str | None = None,
+ input_optionally_enclosed_by: str | None = None,
batch: bool = False,
relaxed_isolation: bool = False,
- extra_export_options: Optional[Dict[str, Any]] = None,
- schema: Optional[str] = None,
- ) -> List[str]:
+ extra_export_options: dict[str, Any] | None = None,
+ schema: str | None = None,
+ ) -> list[str]:
cmd = self._prepare_command(export=True)
@@ -344,7 +344,7 @@ def _export_cmd(
if extra_export_options:
for key, value in extra_export_options.items():
- cmd += [f'--{key}']
+ cmd += [f"--{key}"]
if value:
cmd += [str(value)]
@@ -359,20 +359,20 @@ def _export_cmd(
def export_table(
self,
table: str,
- export_dir: Optional[str] = None,
- input_null_string: Optional[str] = None,
- input_null_non_string: Optional[str] = None,
- staging_table: Optional[str] = None,
+ export_dir: str | None = None,
+ input_null_string: str | None = None,
+ input_null_non_string: str | None = None,
+ staging_table: str | None = None,
clear_staging_table: bool = False,
- enclosed_by: Optional[str] = None,
- escaped_by: Optional[str] = None,
- input_fields_terminated_by: Optional[str] = None,
- input_lines_terminated_by: Optional[str] = None,
- input_optionally_enclosed_by: Optional[str] = None,
+ enclosed_by: str | None = None,
+ escaped_by: str | None = None,
+ input_fields_terminated_by: str | None = None,
+ input_lines_terminated_by: str | None = None,
+ input_optionally_enclosed_by: str | None = None,
batch: bool = False,
relaxed_isolation: bool = False,
- extra_export_options: Optional[Dict[str, Any]] = None,
- schema: Optional[str] = None,
+ extra_export_options: dict[str, Any] | None = None,
+ schema: str | None = None,
) -> None:
"""
Exports Hive table to remote location. Arguments are copies of direct
diff --git a/airflow/providers/apache/sqoop/operators/sqoop.py b/airflow/providers/apache/sqoop/operators/sqoop.py
index 3ad7c5fe1a464..f6a44ebbf64a0 100644
--- a/airflow/providers/apache/sqoop/operators/sqoop.py
+++ b/airflow/providers/apache/sqoop/operators/sqoop.py
@@ -15,11 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""This module contains a sqoop 1 operator"""
+from __future__ import annotations
+
import os
import signal
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
@@ -84,71 +85,71 @@ class SqoopOperator(BaseOperator):
"""
template_fields: Sequence[str] = (
- 'conn_id',
- 'cmd_type',
- 'table',
- 'query',
- 'target_dir',
- 'file_type',
- 'columns',
- 'split_by',
- 'where',
- 'export_dir',
- 'input_null_string',
- 'input_null_non_string',
- 'staging_table',
- 'enclosed_by',
- 'escaped_by',
- 'input_fields_terminated_by',
- 'input_lines_terminated_by',
- 'input_optionally_enclosed_by',
- 'properties',
- 'extra_import_options',
- 'driver',
- 'extra_export_options',
- 'hcatalog_database',
- 'hcatalog_table',
- 'schema',
+ "conn_id",
+ "cmd_type",
+ "table",
+ "query",
+ "target_dir",
+ "file_type",
+ "columns",
+ "split_by",
+ "where",
+ "export_dir",
+ "input_null_string",
+ "input_null_non_string",
+ "staging_table",
+ "enclosed_by",
+ "escaped_by",
+ "input_fields_terminated_by",
+ "input_lines_terminated_by",
+ "input_optionally_enclosed_by",
+ "properties",
+ "extra_import_options",
+ "driver",
+ "extra_export_options",
+ "hcatalog_database",
+ "hcatalog_table",
+ "schema",
)
- template_fields_renderers = {'query': 'sql'}
- ui_color = '#7D8CA4'
+ template_fields_renderers = {"query": "sql"}
+ ui_color = "#7D8CA4"
def __init__(
self,
*,
- conn_id: str = 'sqoop_default',
- cmd_type: str = 'import',
- table: Optional[str] = None,
- query: Optional[str] = None,
- target_dir: Optional[str] = None,
+ conn_id: str = "sqoop_default",
+ cmd_type: str = "import",
+ table: str | None = None,
+ query: str | None = None,
+ target_dir: str | None = None,
append: bool = False,
- file_type: str = 'text',
- columns: Optional[str] = None,
- num_mappers: Optional[int] = None,
- split_by: Optional[str] = None,
- where: Optional[str] = None,
- export_dir: Optional[str] = None,
- input_null_string: Optional[str] = None,
- input_null_non_string: Optional[str] = None,
- staging_table: Optional[str] = None,
+ file_type: str = "text",
+ columns: str | None = None,
+ num_mappers: int | None = None,
+ split_by: str | None = None,
+ where: str | None = None,
+ export_dir: str | None = None,
+ input_null_string: str | None = None,
+ input_null_non_string: str | None = None,
+ staging_table: str | None = None,
clear_staging_table: bool = False,
- enclosed_by: Optional[str] = None,
- escaped_by: Optional[str] = None,
- input_fields_terminated_by: Optional[str] = None,
- input_lines_terminated_by: Optional[str] = None,
- input_optionally_enclosed_by: Optional[str] = None,
+ enclosed_by: str | None = None,
+ escaped_by: str | None = None,
+ input_fields_terminated_by: str | None = None,
+ input_lines_terminated_by: str | None = None,
+ input_optionally_enclosed_by: str | None = None,
batch: bool = False,
direct: bool = False,
- driver: Optional[Any] = None,
+ driver: Any | None = None,
verbose: bool = False,
relaxed_isolation: bool = False,
- properties: Optional[Dict[str, Any]] = None,
- hcatalog_database: Optional[str] = None,
- hcatalog_table: Optional[str] = None,
+ properties: dict[str, Any] | None = None,
+ hcatalog_database: str | None = None,
+ hcatalog_table: str | None = None,
create_hcatalog_table: bool = False,
- extra_import_options: Optional[Dict[str, Any]] = None,
- extra_export_options: Optional[Dict[str, Any]] = None,
- schema: Optional[str] = None,
+ extra_import_options: dict[str, Any] | None = None,
+ extra_export_options: dict[str, Any] | None = None,
+ schema: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -184,15 +185,15 @@ def __init__(
self.properties = properties
self.extra_import_options = extra_import_options or {}
self.extra_export_options = extra_export_options or {}
- self.hook: Optional[SqoopHook] = None
+ self.hook: SqoopHook | None = None
self.schema = schema
- def execute(self, context: "Context") -> None:
+ def execute(self, context: Context) -> None:
"""Execute sqoop job"""
if self.hook is None:
self.hook = self._get_hook()
- if self.cmd_type == 'export':
+ if self.cmd_type == "export":
self.hook.export_table(
table=self.table, # type: ignore
export_dir=self.export_dir,
@@ -210,15 +211,15 @@ def execute(self, context: "Context") -> None:
extra_export_options=self.extra_export_options,
schema=self.schema,
)
- elif self.cmd_type == 'import':
+ elif self.cmd_type == "import":
# add create hcatalog table to extra import options if option passed
# if new params are added to constructor can pass them in here
# so don't modify sqoop_hook for each param
if self.create_hcatalog_table:
- self.extra_import_options['create-hcatalog-table'] = ''
+ self.extra_import_options["create-hcatalog-table"] = ""
if self.table and self.query:
- raise AirflowException('Cannot specify query and table together. Need to specify either or.')
+ raise AirflowException("Cannot specify query and table together. Need to specify either or.")
if self.table:
self.hook.import_table(
@@ -253,7 +254,7 @@ def execute(self, context: "Context") -> None:
def on_kill(self) -> None:
if self.hook is None:
self.hook = self._get_hook()
- self.log.info('Sending SIGTERM signal to bash process group')
+ self.log.info("Sending SIGTERM signal to bash process group")
os.killpg(os.getpgid(self.hook.sub_process_pid), signal.SIGTERM)
def _get_hook(self) -> SqoopHook:
diff --git a/airflow/providers/apache/sqoop/provider.yaml b/airflow/providers/apache/sqoop/provider.yaml
index 51b94f251f81f..b542763fd0f85 100644
--- a/airflow/providers/apache/sqoop/provider.yaml
+++ b/airflow/providers/apache/sqoop/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Apache Sqoop `__
versions:
+ - 3.1.0
+ - 3.0.0
- 2.1.3
- 2.1.2
- 2.1.1
@@ -32,8 +34,8 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
integrations:
- integration-name: Apache Sqoop
@@ -53,9 +55,6 @@ hooks:
python-modules:
- airflow.providers.apache.sqoop.hooks.sqoop
-hook-class-names:
- # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook
connection-types:
- hook-class-name: airflow.providers.apache.sqoop.hooks.sqoop.SqoopHook
diff --git a/airflow/providers/arangodb/.latest-doc-only-change.txt b/airflow/providers/arangodb/.latest-doc-only-change.txt
new file mode 100644
index 0000000000000..ff7136e07d744
--- /dev/null
+++ b/airflow/providers/arangodb/.latest-doc-only-change.txt
@@ -0,0 +1 @@
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/arangodb/CHANGELOG.rst b/airflow/providers/arangodb/CHANGELOG.rst
index 1bafa2d67956f..d6dfb2f0d4010 100644
--- a/airflow/providers/arangodb/CHANGELOG.rst
+++ b/airflow/providers/arangodb/CHANGELOG.rst
@@ -17,8 +17,51 @@
specific language governing permissions and limitations
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+
+2.1.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+* ``Fix links to sources for examples (#24386)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Clean up f-strings in logging calls (#23597)``
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
1.0.0
.....
diff --git a/airflow/providers/arangodb/example_dags/example_arangodb.py b/airflow/providers/arangodb/example_dags/example_arangodb.py
index f9da187cfb665..71c6346ef6073 100644
--- a/airflow/providers/arangodb/example_dags/example_arangodb.py
+++ b/airflow/providers/arangodb/example_dags/example_arangodb.py
@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from datetime import datetime
from airflow.models.dag import DAG
@@ -21,9 +23,9 @@
from airflow.providers.arangodb.sensors.arangodb import AQLSensor
dag = DAG(
- 'example_arangodb_operator',
+ "example_arangodb_operator",
start_date=datetime(2021, 1, 1),
- tags=['example'],
+ tags=["example"],
catchup=False,
)
@@ -41,7 +43,7 @@
# [START howto_aql_sensor_template_file_arangodb]
-sensor = AQLSensor(
+sensor2 = AQLSensor(
task_id="aql_sensor_template_file",
query="search_judy.sql",
timeout=60,
@@ -55,7 +57,7 @@
# [START howto_aql_operator_arangodb]
operator = AQLOperator(
- task_id='aql_operator',
+ task_id="aql_operator",
query="FOR doc IN students RETURN doc",
dag=dag,
result_processor=lambda cursor: print([document["name"] for document in cursor]),
@@ -65,8 +67,8 @@
# [START howto_aql_operator_template_file_arangodb]
-operator = AQLOperator(
- task_id='aql_operator_template_file',
+operator2 = AQLOperator(
+ task_id="aql_operator_template_file",
dag=dag,
result_processor=lambda cursor: print([document["name"] for document in cursor]),
query="search_all.sql",
diff --git a/airflow/providers/arangodb/hooks/arangodb.py b/airflow/providers/arangodb/hooks/arangodb.py
index f88ed9fb33cfa..a30583e837859 100644
--- a/airflow/providers/arangodb/hooks/arangodb.py
+++ b/airflow/providers/arangodb/hooks/arangodb.py
@@ -15,9 +15,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""This module allows connecting to a ArangoDB."""
-from typing import Any, Dict, Optional
+from __future__ import annotations
+
+from typing import Any
from arango import AQLQueryExecuteError, ArangoClient as ArangoDBClient
from arango.result import Result
@@ -35,10 +36,10 @@ class ArangoDBHook(BaseHook):
:param arangodb_conn_id: Reference to :ref:`ArangoDB connection id `.
"""
- conn_name_attr = 'arangodb_conn_id'
- default_conn_name = 'arangodb_default'
- conn_type = 'arangodb'
- hook_name = 'ArangoDB'
+ conn_name_attr = "arangodb_conn_id"
+ default_conn_name = "arangodb_default"
+ conn_type = "arangodb"
+ hook_name = "ArangoDB"
def __init__(self, arangodb_conn_id: str = default_conn_name, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -48,7 +49,7 @@ def __init__(self, arangodb_conn_id: str = default_conn_name, *args, **kwargs) -
self.password = None
self.db_conn = None
self.arangodb_conn_id = arangodb_conn_id
- self.client: Optional[ArangoDBClient] = None
+ self.client: ArangoDBClient | None = None
self.get_conn()
def get_conn(self) -> ArangoDBClient:
@@ -90,7 +91,7 @@ def create_collection(self, name):
self.db_conn.create_collection(name)
return True
else:
- self.log.info('Collection already exists: %s', name)
+ self.log.info("Collection already exists: %s", name)
return False
def create_database(self, name):
@@ -98,7 +99,7 @@ def create_database(self, name):
self.db_conn.create_database(name)
return True
else:
- self.log.info('Database already exists: %s', name)
+ self.log.info("Database already exists: %s", name)
return False
def create_graph(self, name):
@@ -106,24 +107,24 @@ def create_graph(self, name):
self.db_conn.create_graph(name)
return True
else:
- self.log.info('Graph already exists: %s', name)
+ self.log.info("Graph already exists: %s", name)
return False
@staticmethod
- def get_ui_field_behaviour() -> Dict[str, Any]:
+ def get_ui_field_behaviour() -> dict[str, Any]:
return {
- "hidden_fields": ['port', 'extra'],
+ "hidden_fields": ["port", "extra"],
"relabeling": {
- 'host': 'ArangoDB Host URL or comma separated list of URLs (coordinators in a cluster)',
- 'schema': 'ArangoDB Database',
- 'login': 'ArangoDB Username',
- 'password': 'ArangoDB Password',
+ "host": "ArangoDB Host URL or comma separated list of URLs (coordinators in a cluster)",
+ "schema": "ArangoDB Database",
+ "login": "ArangoDB Username",
+ "password": "ArangoDB Password",
},
"placeholders": {
- 'host': 'eg."http://127.0.0.1:8529" or "http://127.0.0.1:8529,http://127.0.0.1:8530"'
- ' (coordinators in a cluster)',
- 'schema': '_system',
- 'login': 'root',
- 'password': 'password',
+ "host": 'eg."http://127.0.0.1:8529" or "http://127.0.0.1:8529,http://127.0.0.1:8530"'
+ " (coordinators in a cluster)",
+ "schema": "_system",
+ "login": "root",
+ "password": "password",
},
}
diff --git a/airflow/providers/arangodb/operators/arangodb.py b/airflow/providers/arangodb/operators/arangodb.py
index 716ca455d7e0a..8f8c357713e74 100644
--- a/airflow/providers/arangodb/operators/arangodb.py
+++ b/airflow/providers/arangodb/operators/arangodb.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Callable, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable, Sequence
from airflow.models import BaseOperator
from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook
@@ -38,7 +40,7 @@ class AQLOperator(BaseOperator):
:param arangodb_conn_id: Reference to :ref:`ArangoDB connection id `.
"""
- template_fields: Sequence[str] = ('query',)
+ template_fields: Sequence[str] = ("query",)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}
@@ -47,8 +49,8 @@ def __init__(
self,
*,
query: str,
- arangodb_conn_id: str = 'arangodb_default',
- result_processor: Optional[Callable] = None,
+ arangodb_conn_id: str = "arangodb_default",
+ result_processor: Callable | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -56,8 +58,8 @@ def __init__(
self.query = query
self.result_processor = result_processor
- def execute(self, context: 'Context'):
- self.log.info('Executing: %s', self.query)
+ def execute(self, context: Context):
+ self.log.info("Executing: %s", self.query)
hook = ArangoDBHook(arangodb_conn_id=self.arangodb_conn_id)
result = hook.query(self.query)
if self.result_processor:
diff --git a/airflow/providers/arangodb/provider.yaml b/airflow/providers/arangodb/provider.yaml
index 129be5b8a0d9c..b830c10976360 100644
--- a/airflow/providers/arangodb/provider.yaml
+++ b/airflow/providers/arangodb/provider.yaml
@@ -20,7 +20,14 @@ package-name: apache-airflow-providers-arangodb
name: ArangoDB
description: |
`ArangoDB `__
+
+dependencies:
+ - apache-airflow>=2.3.0
+ - python-arango>=7.3.2
+
versions:
+ - 2.1.0
+ - 2.0.0
- 1.0.0
integrations:
diff --git a/airflow/providers/arangodb/sensors/arangodb.py b/airflow/providers/arangodb/sensors/arangodb.py
index ee9d0d2a9004d..541cb9b38f8e3 100644
--- a/airflow/providers/arangodb/sensors/arangodb.py
+++ b/airflow/providers/arangodb/sensors/arangodb.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
from typing import TYPE_CHECKING, Sequence
from airflow.providers.arangodb.hooks.arangodb import ArangoDBHook
@@ -36,7 +38,7 @@ class AQLSensor(BaseSensorOperator):
:param arangodb_db: Target ArangoDB name.
"""
- template_fields: Sequence[str] = ('query',)
+ template_fields: Sequence[str] = ("query",)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}
@@ -46,7 +48,7 @@ def __init__(self, *, query: str, arangodb_conn_id: str = "arangodb_default", **
self.arangodb_conn_id = arangodb_conn_id
self.query = query
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
self.log.info("Sensor running the following query: %s", self.query)
hook = ArangoDBHook(self.arangodb_conn_id)
records = hook.query(self.query, count=True).count()
diff --git a/airflow/providers/asana/.latest-doc-only-change.txt b/airflow/providers/asana/.latest-doc-only-change.txt
index ab24993f57139..ff7136e07d744 100644
--- a/airflow/providers/asana/.latest-doc-only-change.txt
+++ b/airflow/providers/asana/.latest-doc-only-change.txt
@@ -1 +1 @@
-8b6b0848a3cacf9999477d6af4d2a87463f03026
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/asana/CHANGELOG.rst b/airflow/providers/asana/CHANGELOG.rst
index 5262e12f92bc0..a868ec662a007 100644
--- a/airflow/providers/asana/CHANGELOG.rst
+++ b/airflow/providers/asana/CHANGELOG.rst
@@ -15,9 +15,75 @@
specific language governing permissions and limitations
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.0.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* In AsanaHook, non-prefixed extra fields are supported and are preferred. So you should update your
+ connection to replace ``extra__asana__workspace`` with ``workspace`` etc.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Features
+~~~~~~~~
+
+* ``Allow and prefer non-prefixed extra fields for AsanaHook (#27043)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+
+2.0.1
+.....
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Update providers to use functools compat for ''cached_property'' (#24582)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+2.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Migrate Asana example DAGs to new design #22440 (#24131)``
+ * ``Prepare provider documentation 2022.05.11 (#23631)``
+ * ``Use new Breese for building, pulling and verifying the images. (#23104)``
+ * ``Fix new MyPy errors in main (#22884)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
1.1.3
.....
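The 3.0.0 breaking-change note above asks users to rename prefixed extra keys such as ``extra__asana__workspace`` to ``workspace``. A minimal sketch of that rename on a connection's extra JSON follows. The helper name ``strip_asana_prefix`` is hypothetical and not part of the provider, and letting the short name win when both forms are present is an assumption of this sketch.

```python
import json

PREFIX = "extra__asana__"  # legacy prefix named in the changelog entry


def strip_asana_prefix(extra_json: str) -> str:
    """Return the extra JSON with legacy prefixed keys renamed to short keys.

    Hypothetical helper for illustration; not part of the Asana provider.
    """
    extras = json.loads(extra_json)
    migrated = {}
    for key, value in extras.items():
        short = key[len(PREFIX):] if key.startswith(PREFIX) else key
        # Assumption: if both forms exist, the value under the short name wins.
        if short in migrated and key != short:
            continue
        migrated[short] = value
    return json.dumps(migrated, sort_keys=True)


print(strip_asana_prefix('{"extra__asana__workspace": "1", "project": "2"}'))
```

Keys that never carried the prefix pass through unchanged, so the helper can be run repeatedly on the same connection.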
diff --git a/airflow/providers/asana/hooks/asana.py b/airflow/providers/asana/hooks/asana.py
index 45816437f81b8..544a5afb59961 100644
--- a/airflow/providers/asana/hooks/asana.py
+++ b/airflow/providers/asana/hooks/asana.py
@@ -15,22 +15,47 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
"""Connect to Asana."""
-import sys
-from typing import Any, Dict, Optional
+from __future__ import annotations
+
+from functools import wraps
+from typing import Any
from asana import Client # type: ignore[attr-defined]
from asana.error import NotFoundError # type: ignore[attr-defined]
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
+from airflow.compat.functools import cached_property
from airflow.hooks.base import BaseHook
+def _ensure_prefixes(conn_type):
+ """
+ Remove when provider min airflow version >= 2.5.0 since this is handled by
+ provider manager from that version.
+ """
+
+ def dec(func):
+ @wraps(func)
+ def inner():
+ field_behaviors = func()
+ conn_attrs = {"host", "schema", "login", "password", "port", "extra"}
+
+ def _ensure_prefix(field):
+ if field not in conn_attrs and not field.startswith("extra__"):
+ return f"extra__{conn_type}__{field}"
+ else:
+ return field
+
+ if "placeholders" in field_behaviors:
+ placeholders = field_behaviors["placeholders"]
+ field_behaviors["placeholders"] = {_ensure_prefix(k): v for k, v in placeholders.items()}
+ return field_behaviors
+
+ return inner
+
+ return dec
+
+
class AsanaHook(BaseHook):
"""Wrapper around Asana Python client library."""
@@ -43,34 +68,48 @@ def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.connection = self.get_connection(conn_id)
extras = self.connection.extra_dejson
- self.workspace = extras.get("extra__asana__workspace") or None
- self.project = extras.get("extra__asana__project") or None
+ self.workspace = self._get_field(extras, "workspace") or None
+ self.project = self._get_field(extras, "project") or None
+
+ def _get_field(self, extras: dict, field_name: str):
+ """Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
+ backcompat_prefix = "extra__asana__"
+ if field_name.startswith("extra__"):
+ raise ValueError(
+ f"Got prefixed name {field_name}; please remove the '{backcompat_prefix}' prefix "
+ "when using this method."
+ )
+ if field_name in extras:
+ return extras[field_name] or None
+ prefixed_name = f"{backcompat_prefix}{field_name}"
+ return extras.get(prefixed_name) or None
def get_conn(self) -> Client:
return self.client
@staticmethod
- def get_connection_form_widgets() -> Dict[str, Any]:
+ def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import StringField
return {
- "extra__asana__workspace": StringField(lazy_gettext("Workspace"), widget=BS3TextFieldWidget()),
- "extra__asana__project": StringField(lazy_gettext("Project"), widget=BS3TextFieldWidget()),
+ "workspace": StringField(lazy_gettext("Workspace"), widget=BS3TextFieldWidget()),
+ "project": StringField(lazy_gettext("Project"), widget=BS3TextFieldWidget()),
}
@staticmethod
- def get_ui_field_behaviour() -> Dict[str, Any]:
+ @_ensure_prefixes(conn_type="asana")
+ def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
"hidden_fields": ["port", "host", "login", "schema"],
"relabeling": {},
"placeholders": {
"password": "Asana personal access token",
- "extra__asana__workspace": "Asana workspace gid",
- "extra__asana__project": "Asana project gid",
+ "workspace": "Asana workspace gid",
+ "project": "Asana project gid",
},
}
@@ -85,7 +124,7 @@ def client(self) -> Client:
return Client.access_token(self.connection.password)
- def create_task(self, task_name: str, params: Optional[dict]) -> dict:
+ def create_task(self, task_name: str, params: dict | None) -> dict:
"""
Creates an Asana task.
@@ -99,7 +138,7 @@ def create_task(self, task_name: str, params: Optional[dict]) -> dict:
response = self.client.tasks.create(params=merged_params)
return response
- def _merge_create_task_parameters(self, task_name: str, task_params: Optional[dict]) -> dict:
+ def _merge_create_task_parameters(self, task_name: str, task_params: dict | None) -> dict:
"""
Merge create_task parameters with default params from the connection.
@@ -107,7 +146,7 @@ def _merge_create_task_parameters(self, task_name: str, task_params: Optional[di
:param task_params: Other task parameters which should override defaults from the connection
:return: A dict of merged parameters to use in the new task
"""
- merged_params: Dict[str, Any] = {"name": task_name}
+ merged_params: dict[str, Any] = {"name": task_name}
if self.project:
merged_params["projects"] = [self.project]
# Only use default workspace if user did not provide a project id
@@ -145,7 +184,7 @@ def delete_task(self, task_id: str) -> dict:
self.log.info("Asana task %s not found for deletion.", task_id)
return {}
- def find_task(self, params: Optional[dict]) -> list:
+ def find_task(self, params: dict | None) -> list:
"""
Retrieves a list of Asana tasks that match search parameters.
@@ -158,7 +197,7 @@ def find_task(self, params: Optional[dict]) -> list:
response = self.client.tasks.find_all(params=merged_params)
return list(response)
- def _merge_find_task_parameters(self, search_parameters: Optional[dict]) -> dict:
+ def _merge_find_task_parameters(self, search_parameters: dict | None) -> dict:
"""
Merge find_task parameters with default params from the connection.
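The hook changes above introduce two back-compat pieces: a lookup that prefers the short extra name and falls back to the prefixed one, and a decorator that re-adds the prefix to custom placeholder keys for older Airflow versions. The sketch below restates both as standalone functions so the precedence rules can be tried without Airflow installed. The names ``get_field``, ``ensure_prefixes`` and ``ui_field_behaviour`` are illustrative stand-ins, not the provider's API.

```python
from functools import wraps

# Standard connection attributes that never take a prefix.
CONN_ATTRS = {"host", "schema", "login", "password", "port", "extra"}


def get_field(extras: dict, field_name: str, prefix: str = "extra__asana__"):
    """Look up the short name first, then fall back to the legacy prefixed name."""
    if field_name.startswith("extra__"):
        raise ValueError(f"Pass the short name, not the prefixed name {field_name!r}")
    if field_name in extras:
        # A short key that is present but empty yields None; no fallback happens.
        return extras[field_name] or None
    return extras.get(f"{prefix}{field_name}") or None


def ensure_prefixes(conn_type: str):
    """Decorator that re-adds the legacy prefix to custom placeholder keys."""

    def dec(func):
        @wraps(func)
        def inner():
            behaviour = func()
            if "placeholders" in behaviour:
                behaviour["placeholders"] = {
                    (k if k in CONN_ATTRS or k.startswith("extra__") else f"extra__{conn_type}__{k}"): v
                    for k, v in behaviour["placeholders"].items()
                }
            return behaviour

        return inner

    return dec


@ensure_prefixes(conn_type="asana")
def ui_field_behaviour() -> dict:
    # Hypothetical field-behaviour dict, shortened for the example.
    return {"placeholders": {"password": "token", "workspace": "gid"}}
```

The lookup raises on a prefixed argument so that callers cannot silently depend on the legacy naming.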
diff --git a/airflow/providers/asana/operators/asana_tasks.py b/airflow/providers/asana/operators/asana_tasks.py
index 66d88291a3903..9d204153d4e16 100644
--- a/airflow/providers/asana/operators/asana_tasks.py
+++ b/airflow/providers/asana/operators/asana_tasks.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
from airflow.models import BaseOperator
from airflow.providers.asana.hooks.asana import AsanaHook
@@ -47,7 +48,7 @@ def __init__(
*,
conn_id: str,
name: str,
- task_parameters: Optional[dict] = None,
+ task_parameters: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -56,7 +57,7 @@ def __init__(
self.name = name
self.task_parameters = task_parameters
- def execute(self, context: 'Context') -> str:
+ def execute(self, context: Context) -> str:
hook = AsanaHook(conn_id=self.conn_id)
response = hook.create_task(self.name, self.task_parameters)
self.log.info(response)
@@ -93,7 +94,7 @@ def __init__(
self.asana_task_gid = asana_task_gid
self.task_parameters = task_parameters
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
hook = AsanaHook(conn_id=self.conn_id)
response = hook.update_task(self.asana_task_gid, self.task_parameters)
self.log.info(response)
@@ -123,7 +124,7 @@ def __init__(
self.conn_id = conn_id
self.asana_task_gid = asana_task_gid
- def execute(self, context: 'Context') -> None:
+ def execute(self, context: Context) -> None:
hook = AsanaHook(conn_id=self.conn_id)
response = hook.delete_task(self.asana_task_gid)
self.log.info(response)
@@ -148,7 +149,7 @@ def __init__(
self,
*,
conn_id: str,
- search_parameters: Optional[dict] = None,
+ search_parameters: dict | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -156,7 +157,7 @@ def __init__(
self.conn_id = conn_id
self.search_parameters = search_parameters
- def execute(self, context: 'Context') -> list:
+ def execute(self, context: Context) -> list:
hook = AsanaHook(conn_id=self.conn_id)
response = hook.find_task(self.search_parameters)
self.log.info(response)
diff --git a/airflow/providers/asana/provider.yaml b/airflow/providers/asana/provider.yaml
index e82dd7c6ecc45..d8b27c99080b1 100644
--- a/airflow/providers/asana/provider.yaml
+++ b/airflow/providers/asana/provider.yaml
@@ -22,14 +22,18 @@ description: |
`Asana `__
versions:
+ - 3.0.0
+ - 2.0.1
+ - 2.0.0
- 1.1.3
- 1.1.2
- 1.1.1
- 1.1.0
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - asana>=0.10
integrations:
- integration-name: Asana
@@ -48,8 +52,6 @@ hooks:
python-modules:
- airflow.providers.asana.hooks.asana
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.asana.hooks.asana.AsanaHook
connection-types:
- hook-class-name: airflow.providers.asana.hooks.asana.AsanaHook
diff --git a/airflow/providers/apache/drill/example_dags/__init__.py b/airflow/providers/atlassian/__init__.py
similarity index 100%
rename from airflow/providers/apache/drill/example_dags/__init__.py
rename to airflow/providers/atlassian/__init__.py
diff --git a/airflow/providers/atlassian/jira/CHANGELOG.rst b/airflow/providers/atlassian/jira/CHANGELOG.rst
new file mode 100644
index 0000000000000..e316872b87bfb
--- /dev/null
+++ b/airflow/providers/atlassian/jira/CHANGELOG.rst
@@ -0,0 +1,45 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
+Changelog
+---------
+
+1.1.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+1.0.0
+.....
+
+Initial version of the provider.
diff --git a/airflow/providers/apache/kylin/example_dags/__init__.py b/airflow/providers/atlassian/jira/__init__.py
similarity index 100%
rename from airflow/providers/apache/kylin/example_dags/__init__.py
rename to airflow/providers/atlassian/jira/__init__.py
diff --git a/airflow/providers/apache/livy/example_dags/__init__.py b/airflow/providers/atlassian/jira/hooks/__init__.py
similarity index 100%
rename from airflow/providers/apache/livy/example_dags/__init__.py
rename to airflow/providers/atlassian/jira/hooks/__init__.py
diff --git a/airflow/providers/atlassian/jira/hooks/jira.py b/airflow/providers/atlassian/jira/hooks/jira.py
new file mode 100644
index 0000000000000..7ce1a9e80a16a
--- /dev/null
+++ b/airflow/providers/atlassian/jira/hooks/jira.py
@@ -0,0 +1,90 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for JIRA"""
+from __future__ import annotations
+
+from typing import Any
+
+from jira import JIRA
+from jira.exceptions import JIRAError
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+
+
+class JiraHook(BaseHook):
+ """
+ Jira interaction hook, a Wrapper around JIRA Python SDK.
+
+ :param jira_conn_id: reference to a pre-defined Jira Connection
+ """
+
+ default_conn_name = "jira_default"
+ conn_type = "jira"
+ conn_name_attr = "jira_conn_id"
+ hook_name = "JIRA"
+
+ def __init__(self, jira_conn_id: str = default_conn_name, proxies: Any | None = None) -> None:
+ super().__init__()
+ self.jira_conn_id = jira_conn_id
+ self.proxies = proxies
+ self.client: JIRA | None = None
+ self.get_conn()
+
+ def get_conn(self) -> JIRA:
+ if not self.client:
+ self.log.debug("Creating Jira client for conn_id: %s", self.jira_conn_id)
+
+ get_server_info = True
+ validate = True
+ extra_options = {}
+ if not self.jira_conn_id:
+ raise AirflowException("Failed to create jira client. no jira_conn_id provided")
+
+ conn = self.get_connection(self.jira_conn_id)
+ if conn.extra is not None:
+ extra_options = conn.extra_dejson
+ # only required attributes are taken for now,
+ # more can be added ex: async, logging, max_retries
+
+ # verify
+ if "verify" in extra_options and extra_options["verify"].lower() == "false":
+ extra_options["verify"] = False
+
+ # validate
+ if "validate" in extra_options and extra_options["validate"].lower() == "false":
+ validate = False
+
+ if "get_server_info" in extra_options and extra_options["get_server_info"].lower() == "false":
+ get_server_info = False
+
+ try:
+ self.client = JIRA(
+ conn.host,
+ options=extra_options,
+ basic_auth=(conn.login, conn.password),
+ get_server_info=get_server_info,
+ validate=validate,
+ proxies=self.proxies,
+ )
+ except JIRAError as jira_error:
+ raise AirflowException(f"Failed to create jira client, jira error: {str(jira_error)}")
+ except Exception as e:
+ raise AirflowException(f"Failed to create jira client, error: {str(e)}")
+
+ return self.client
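``JiraHook.get_conn`` above reads ``verify``, ``validate`` and ``get_server_info`` from the connection extras and treats the string ``"false"`` as a disabled flag. A minimal sketch of that parsing, separated from the client construction, is shown here. The function names are hypothetical, and also accepting a real boolean ``False`` is an assumption of this sketch rather than what the hook does.

```python
from typing import Any, Dict, Tuple


def _is_false(value: Any) -> bool:
    """Treat the boolean False and the string "false" (any case) as false."""
    return value is False or (isinstance(value, str) and value.lower() == "false")


def split_jira_extras(extras: Dict[str, Any]) -> Tuple[Dict[str, Any], bool, bool]:
    """Return (options, validate, get_server_info) derived from connection extras."""
    options = dict(extras)  # copy so the caller's dict is not mutated
    if _is_false(options.get("verify")):
        options["verify"] = False
    # Both flags default to True when absent, mirroring the hook's defaults.
    validate = not _is_false(options.get("validate"))
    get_server_info = not _is_false(options.get("get_server_info"))
    return options, validate, get_server_info


print(split_jira_extras({"verify": "False", "validate": "false"}))
```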
diff --git a/airflow/providers/dbt/cloud/example_dags/__init__.py b/airflow/providers/atlassian/jira/operators/__init__.py
similarity index 100%
rename from airflow/providers/dbt/cloud/example_dags/__init__.py
rename to airflow/providers/atlassian/jira/operators/__init__.py
diff --git a/airflow/providers/atlassian/jira/operators/jira.py b/airflow/providers/atlassian/jira/operators/jira.py
new file mode 100644
index 0000000000000..b1fe7bb05cd65
--- /dev/null
+++ b/airflow/providers/atlassian/jira/operators/jira.py
@@ -0,0 +1,91 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.atlassian.jira.hooks.jira import JIRAError, JiraHook
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class JiraOperator(BaseOperator):
+ """
+ JiraOperator to interact with and perform actions on the Jira issue tracking system.
+ This operator is designed to use Jira Python SDK: http://jira.readthedocs.io
+
+ :param jira_conn_id: reference to a pre-defined Jira Connection
+ :param jira_method: method name from Jira Python SDK to be called
+ :param jira_method_args: required method parameters for the jira_method. (templated)
+ :param result_processor: function to further process the response from Jira
+ :param get_jira_resource_method: function or operator to get jira resource
+ on which the provided jira_method will be executed
+ """
+
+ template_fields: Sequence[str] = ("jira_method_args",)
+
+ def __init__(
+ self,
+ *,
+ jira_method: str,
+ jira_conn_id: str = "jira_default",
+ jira_method_args: dict | None = None,
+ result_processor: Callable | None = None,
+ get_jira_resource_method: Callable | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.jira_conn_id = jira_conn_id
+ self.method_name = jira_method
+ self.jira_method_args = jira_method_args
+ self.result_processor = result_processor
+ self.get_jira_resource_method = get_jira_resource_method
+
+ def execute(self, context: Context) -> Any:
+ try:
+ if self.get_jira_resource_method is not None:
+ # if get_jira_resource_method is provided, jira_method will be executed on
+ # resource returned by executing the get_jira_resource_method.
+ # This makes all the provided methods of JIRA sdk accessible and usable
+ # directly at the JiraOperator without additional wrappers.
+ # ref: http://jira.readthedocs.io/en/latest/api.html
+ if isinstance(self.get_jira_resource_method, JiraOperator):
+ resource = self.get_jira_resource_method.execute(**context)
+ else:
+ resource = self.get_jira_resource_method(**context)
+ else:
+ # Default method execution is on the top level jira client resource
+ hook = JiraHook(jira_conn_id=self.jira_conn_id)
+ resource = hook.client
+
+ # Current Jira-Python SDK (1.0.7) has issue with pickling the jira response.
+ # ex: self.xcom_push(context, key='operator_response', value=jira_response)
+ # This could potentially throw error if jira_result is not picklable
+ jira_result = getattr(resource, self.method_name)(**(self.jira_method_args or {}))
+ if self.result_processor:
+ return self.result_processor(context, jira_result)
+
+ return jira_result
+
+ except JIRAError as jira_error:
+ raise AirflowException(f"Failed to execute jiraOperator, error: {str(jira_error)}")
+ except Exception as e:
+ raise AirflowException(f"Jira operator error: {str(e)}")
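``JiraOperator.execute`` above resolves ``jira_method`` by name on a resource and calls it with ``jira_method_args``, then optionally hands the result to ``result_processor``. The sketch below shows that dispatch pattern against a fake client. ``FakeJiraClient`` and ``call_method`` are hypothetical names used only for illustration, and guarding a ``None`` argument dict with ``or {}`` is a choice made for this sketch.

```python
from typing import Any, Callable, Dict, Optional


class FakeJiraClient:
    """Hypothetical stand-in for a Jira client; not the real SDK."""

    def add_comment(self, issue: str, body: str) -> str:
        return f"{issue}: {body}"


def call_method(
    context: Dict[str, Any],
    resource: Any,
    method_name: str,
    method_args: Optional[Dict[str, Any]] = None,
    result_processor: Optional[Callable] = None,
) -> Any:
    """Look up method_name on resource, call it with keyword args, post-process."""
    result = getattr(resource, method_name)(**(method_args or {}))
    if result_processor:
        # The processor receives the context first, then the raw result.
        return result_processor(context, result)
    return result


print(call_method({}, FakeJiraClient(), "add_comment", {"issue": "ABC-1", "body": "hi"}))
```

Name-based dispatch keeps the operator thin, at the cost of surfacing a misspelled method name only at run time as an ``AttributeError``.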
diff --git a/airflow/providers/atlassian/jira/provider.yaml b/airflow/providers/atlassian/jira/provider.yaml
new file mode 100644
index 0000000000000..3e68f29ec4bd8
--- /dev/null
+++ b/airflow/providers/atlassian/jira/provider.yaml
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-atlassian-jira
+name: Atlassian Jira
+description: |
+ `Atlassian Jira `__
+
+versions:
+ - 1.1.0
+ - 1.0.0
+
+dependencies:
+ - apache-airflow>=2.3.0
+ - JIRA>1.0.7
+
+integrations:
+ - integration-name: Atlassian Jira
+ external-doc-url: https://www.atlassian.com/pl/software/jira
+ logo: /integration-logos/jira/Jira.png
+ tags: [software]
+
+operators:
+ - integration-name: Atlassian Jira
+ python-modules:
+ - airflow.providers.atlassian.jira.operators.jira
+
+sensors:
+ - integration-name: Atlassian Jira
+ python-modules:
+ - airflow.providers.atlassian.jira.sensors.jira
+
+hooks:
+ - integration-name: Atlassian Jira
+ python-modules:
+ - airflow.providers.atlassian.jira.hooks.jira
+
+connection-types:
+ - hook-class-name: airflow.providers.atlassian.jira.hooks.jira.JiraHook
+ connection-type: jira
diff --git a/airflow/providers/elasticsearch/example_dags/__init__.py b/airflow/providers/atlassian/jira/sensors/__init__.py
similarity index 100%
rename from airflow/providers/elasticsearch/example_dags/__init__.py
rename to airflow/providers/atlassian/jira/sensors/__init__.py
diff --git a/airflow/providers/atlassian/jira/sensors/jira.py b/airflow/providers/atlassian/jira/sensors/jira.py
new file mode 100644
index 0000000000000..cd3af90dab2ae
--- /dev/null
+++ b/airflow/providers/atlassian/jira/sensors/jira.py
@@ -0,0 +1,139 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Callable, Sequence
+
+from jira.resources import Issue, Resource
+
+from airflow.providers.atlassian.jira.hooks.jira import JiraHook
+from airflow.providers.atlassian.jira.operators.jira import JIRAError
+from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class JiraSensor(BaseSensorOperator):
+ """
+ Monitors a jira ticket for any change.
+
+ :param jira_conn_id: reference to a pre-defined Jira Connection
+ :param method_name: method name from jira-python-sdk to be executed
+ :param method_params: parameters for the method method_name
+ :param result_processor: function that returns a boolean and acts as the sensor response
+ """
+
+ def __init__(
+ self,
+ *,
+ method_name: str,
+ jira_conn_id: str = "jira_default",
+ method_params: dict | None = None,
+ result_processor: Callable | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.jira_conn_id = jira_conn_id
+ self.result_processor = None
+ if result_processor is not None:
+ self.result_processor = result_processor
+ self.method_name = method_name
+ self.method_params = method_params
+
+ def poke(self, context: Context) -> Any:
+ hook = JiraHook(jira_conn_id=self.jira_conn_id)
+ resource = hook.get_conn()
+ jira_result = getattr(resource, self.method_name)(**(self.method_params or {}))
+ if self.result_processor is None:
+ return jira_result
+ return self.result_processor(jira_result)
+
+
+class JiraTicketSensor(JiraSensor):
+ """
+ Monitors a jira ticket for given change in terms of function.
+
+ :param jira_conn_id: reference to a pre-defined Jira Connection
+ :param ticket_id: id of the ticket to be monitored
+ :param field: field of the ticket to be monitored
+ :param expected_value: expected value of the field
+ :param result_processor: function that returns a boolean and acts as the sensor response
+ """
+
+ template_fields: Sequence[str] = ("ticket_id",)
+
+ def __init__(
+ self,
+ *,
+ jira_conn_id: str = "jira_default",
+ ticket_id: str | None = None,
+ field: str | None = None,
+ expected_value: str | None = None,
+ field_checker_func: Callable | None = None,
+ **kwargs,
+ ) -> None:
+
+ self.jira_conn_id = jira_conn_id
+ self.ticket_id = ticket_id
+ self.field = field
+ self.expected_value = expected_value
+ if field_checker_func is None:
+ field_checker_func = self.issue_field_checker
+
+ super().__init__(jira_conn_id=jira_conn_id, result_processor=field_checker_func, **kwargs)
+
+ def poke(self, context: Context) -> Any:
+ self.log.info("Jira Sensor checking for change in ticket: %s", self.ticket_id)
+
+ self.method_name = "issue"
+ self.method_params = {"id": self.ticket_id, "fields": self.field}
+ return JiraSensor.poke(self, context=context)
+
+ def issue_field_checker(self, issue: Issue) -> bool | None:
+ """Check issue using different conditions to prepare to evaluate sensor."""
+ result = None
+ try:
+ if issue is not None and self.field is not None and self.expected_value is not None:
+
+ field_val = getattr(issue.fields, self.field)
+ if field_val is not None:
+ if isinstance(field_val, list):
+ result = self.expected_value in field_val
+ elif isinstance(field_val, str):
+ result = self.expected_value.lower() == field_val.lower()
+ elif isinstance(field_val, Resource) and getattr(field_val, "name"):
+ result = self.expected_value.lower() == field_val.name.lower()
+ else:
+ self.log.warning(
+ "Not implemented checker for issue field %s which "
+ "is neither string nor list nor Jira Resource",
+ self.field,
+ )
+
+ except JIRAError as jira_error:
+ self.log.error("Jira error while checking with expected value: %s", jira_error)
+ except Exception:
+ self.log.exception("Error while checking with expected value %s:", self.expected_value)
+ if result is True:
+ self.log.info(
+ "Issue field %s has expected value %s, returning success", self.field, self.expected_value
+ )
+ else:
+ self.log.info("Issue field %s does not have expected value %s yet.", self.field, self.expected_value)
+ return result
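``JiraTicketSensor.issue_field_checker`` above compares an issue field against ``expected_value`` in three ways, depending on whether the field is a list, a string, or a Jira resource with a ``name``. A minimal sketch of those comparison rules on plain values follows. ``field_matches`` is a hypothetical name, and it uses duck typing on ``name`` instead of the ``isinstance(field_val, Resource)`` check so that it runs without the ``jira`` package.

```python
from types import SimpleNamespace
from typing import Any, Optional


def field_matches(field_val: Any, expected: str) -> Optional[bool]:
    """Return True/False when a comparison applies, None when it does not."""
    if field_val is None:
        return None
    if isinstance(field_val, list):
        # Lists use exact membership, so this branch is case-sensitive.
        return expected in field_val
    if isinstance(field_val, str):
        return expected.lower() == field_val.lower()
    name = getattr(field_val, "name", None)
    if name:
        # Assumption: any object with a non-empty name is treated like a resource.
        return expected.lower() == name.lower()
    return None  # unsupported field type


print(field_matches(SimpleNamespace(name="Done"), "done"))
```

A ``None`` result is falsy, so the sensor keeps poking in that case.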
diff --git a/airflow/providers/celery/.latest-doc-only-change.txt b/airflow/providers/celery/.latest-doc-only-change.txt
index 28124098645cf..ff7136e07d744 100644
--- a/airflow/providers/celery/.latest-doc-only-change.txt
+++ b/airflow/providers/celery/.latest-doc-only-change.txt
@@ -1 +1 @@
-6c3a67d4fccafe4ab6cd9ec8c7bacf2677f17038
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/celery/CHANGELOG.rst b/airflow/providers/celery/CHANGELOG.rst
index 4a291303f63af..33e86a4d7fb5c 100644
--- a/airflow/providers/celery/CHANGELOG.rst
+++ b/airflow/providers/celery/CHANGELOG.rst
@@ -16,9 +16,50 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.1.0
+.....
+
+This release of the provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of the provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.1.4
.....
diff --git a/airflow/providers/celery/provider.yaml b/airflow/providers/celery/provider.yaml
index 058796dc23302..8adc4c1c6c12c 100644
--- a/airflow/providers/celery/provider.yaml
+++ b/airflow/providers/celery/provider.yaml
@@ -22,6 +22,8 @@ description: |
`Celery `__
versions:
+ - 3.1.0
+ - 3.0.0
- 2.1.4
- 2.1.3
- 2.1.2
@@ -31,8 +33,14 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.2.0
+dependencies:
+ - apache-airflow>=2.3.0
+ # Celery is known to introduce problems when upgraded to a MAJOR version. Airflow core
+ # uses Celery for CeleryExecutor, and we also know that Celery follows SemVer
+ # (https://docs.celeryq.dev/en/stable/contributing.html?highlight=semver#versions).
+ # Make sure that the limit here is synchronized with the [celery] extra in Airflow core.
+ - celery>=5.2.3,<6
+ - flower>=1.0.0
integrations:
- integration-name: Celery
diff --git a/airflow/providers/celery/sensors/celery_queue.py b/airflow/providers/celery/sensors/celery_queue.py
index 5a7674ae6bc09..78e412c06a0b6 100644
--- a/airflow/providers/celery/sensors/celery_queue.py
+++ b/airflow/providers/celery/sensors/celery_queue.py
@@ -15,8 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
from celery.app import control
@@ -36,13 +37,13 @@ class CeleryQueueSensor(BaseSensorOperator):
:param target_task_id: Task id for checking
"""
- def __init__(self, *, celery_queue: str, target_task_id: Optional[str] = None, **kwargs) -> None:
+ def __init__(self, *, celery_queue: str, target_task_id: str | None = None, **kwargs) -> None:
super().__init__(**kwargs)
self.celery_queue = celery_queue
self.target_task_id = target_task_id
- def _check_task_id(self, context: 'Context') -> bool:
+ def _check_task_id(self, context: Context) -> bool:
"""
Gets the returned Celery result from the Airflow task
ID provided to the sensor, and returns True if the
@@ -50,13 +51,12 @@ def _check_task_id(self, context: 'Context') -> bool:
:param context: Airflow's execution context
:return: True if task has been executed, otherwise False
- :rtype: bool
"""
- ti = context['ti']
+ ti = context["ti"]
celery_result = ti.xcom_pull(task_ids=self.target_task_id)
return celery_result.ready()
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
if self.target_task_id:
return self._check_task_id(context)
@@ -71,8 +71,8 @@ def poke(self, context: 'Context') -> bool:
scheduled = len(scheduled[self.celery_queue])
active = len(active[self.celery_queue])
- self.log.info('Checking if celery queue %s is empty.', self.celery_queue)
+ self.log.info("Checking if celery queue %s is empty.", self.celery_queue)
return reserved == 0 and scheduled == 0 and active == 0
except KeyError:
- raise KeyError(f'Could not locate Celery queue {self.celery_queue}')
+ raise KeyError(f"Could not locate Celery queue {self.celery_queue}")
diff --git a/airflow/providers/cloudant/.latest-doc-only-change.txt b/airflow/providers/cloudant/.latest-doc-only-change.txt
index ab24993f57139..ff7136e07d744 100644
--- a/airflow/providers/cloudant/.latest-doc-only-change.txt
+++ b/airflow/providers/cloudant/.latest-doc-only-change.txt
@@ -1 +1 @@
-8b6b0848a3cacf9999477d6af4d2a87463f03026
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/cloudant/CHANGELOG.rst b/airflow/providers/cloudant/CHANGELOG.rst
index a0167e8d356ba..8e452f1088931 100644
--- a/airflow/providers/cloudant/CHANGELOG.rst
+++ b/airflow/providers/cloudant/CHANGELOG.rst
@@ -16,9 +16,54 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.1.0
+.....
+
+This release of the provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add documentation for July 2022 Provider's release (#25030)``
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of the provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare provider documentation 2022.05.11 (#23631)``
+ * ``Use new Breese for building, pulling and verifying the images. (#23104)``
+ * ``Fix new MyPy errors in main (#22884)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.0.4
.....
diff --git a/airflow/providers/cloudant/hooks/cloudant.py b/airflow/providers/cloudant/hooks/cloudant.py
index 2398376dc4b24..248ee4d32c82a 100644
--- a/airflow/providers/cloudant/hooks/cloudant.py
+++ b/airflow/providers/cloudant/hooks/cloudant.py
@@ -16,7 +16,9 @@
# specific language governing permissions and limitations
# under the License.
"""Hook for Cloudant"""
-from typing import Any, Dict
+from __future__ import annotations
+
+from typing import Any
from cloudant import cloudant # type: ignore[attr-defined]
@@ -33,17 +35,17 @@ class CloudantHook(BaseHook):
:param cloudant_conn_id: The connection id to authenticate and get a session object from cloudant.
"""
- conn_name_attr = 'cloudant_conn_id'
- default_conn_name = 'cloudant_default'
- conn_type = 'cloudant'
- hook_name = 'Cloudant'
+ conn_name_attr = "cloudant_conn_id"
+ default_conn_name = "cloudant_default"
+ conn_type = "cloudant"
+ hook_name = "Cloudant"
@staticmethod
- def get_ui_field_behaviour() -> Dict[str, Any]:
+ def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
- "hidden_fields": ['port', 'extra'],
- "relabeling": {'host': 'Account', 'login': 'Username (or API Key)', 'schema': 'Database'},
+ "hidden_fields": ["port", "extra"],
+ "relabeling": {"host": "Account", "login": "Username (or API Key)", "schema": "Database"},
}
def __init__(self, cloudant_conn_id: str = default_conn_name) -> None:
@@ -61,7 +63,6 @@ def get_conn(self) -> cloudant:
- 'password' equals the 'Password' (required)
:return: an authorized cloudant session context manager object.
- :rtype: cloudant
"""
conn = self.get_connection(self.cloudant_conn_id)
@@ -72,6 +73,6 @@ def get_conn(self) -> cloudant:
return cloudant_session
def _validate_connection(self, conn: cloudant) -> None:
- for conn_param in ['login', 'password']:
+ for conn_param in ["login", "password"]:
if not getattr(conn, conn_param):
- raise AirflowException(f'missing connection parameter {conn_param}')
+ raise AirflowException(f"missing connection parameter {conn_param}")
diff --git a/airflow/providers/cloudant/provider.yaml b/airflow/providers/cloudant/provider.yaml
index 9c4479f1859c5..8495f93af0b22 100644
--- a/airflow/providers/cloudant/provider.yaml
+++ b/airflow/providers/cloudant/provider.yaml
@@ -22,6 +22,8 @@ description: |
`IBM Cloudant `__
versions:
+ - 3.1.0
+ - 3.0.0
- 2.0.4
- 2.0.3
- 2.0.2
@@ -30,8 +32,9 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
- - apache-airflow>=2.1.0
+dependencies:
+ - apache-airflow>=2.3.0
+ - cloudant>=2.0
integrations:
- integration-name: IBM Cloudant
@@ -44,9 +47,6 @@ hooks:
python-modules:
- airflow.providers.cloudant.hooks.cloudant
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.cloudant.hooks.cloudant.CloudantHook
-
connection-types:
- hook-class-name: airflow.providers.cloudant.hooks.cloudant.CloudantHook
connection-type: cloudant
diff --git a/airflow/providers/cncf/kubernetes/CHANGELOG.rst b/airflow/providers/cncf/kubernetes/CHANGELOG.rst
index bb2d236594044..f7b293fd1fd40 100644
--- a/airflow/providers/cncf/kubernetes/CHANGELOG.rst
+++ b/airflow/providers/cncf/kubernetes/CHANGELOG.rst
@@ -16,21 +16,192 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
-main
-....
+5.0.0
+.....
+
+This release of the provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+Previously KubernetesPodOperator considered some settings from the Airflow config's ``kubernetes`` section.
+Such consideration was deprecated in 4.1.0 and is now removed. If you previously relied on the Airflow
+config, and you want client generation to have non-default configuration, you will need to define your
+configuration in an Airflow connection and set KPO to use the connection. See kubernetes provider
+documentation on defining a kubernetes Airflow connection for details.
+
+Drop support for providing ``resource`` as dict in ``KubernetesPodOperator``. You
+should use ``container_resources`` with ``V1ResourceRequirements``.
+
+Param ``node_selectors`` has been removed in ``KubernetesPodOperator``; use ``node_selector`` instead.
+
+The following backcompat modules for KubernetesPodOperator are removed and you must now use
+the corresponding objects from the kubernetes library:
+
+* ``airflow.providers.cncf.kubernetes.backcompat.pod``
+* ``airflow.providers.cncf.kubernetes.backcompat.pod_runtime_info_env``
+* ``airflow.providers.cncf.kubernetes.backcompat.volume``
+* ``airflow.providers.cncf.kubernetes.backcompat.volume_mount``
+
+In ``KubernetesHook.get_namespace``, if a connection is defined but a namespace isn't set, we
+currently return 'default'; this behavior is deprecated. In the next release, we'll return ``None``.
+
+* ``Remove deprecated backcompat objects for KPO (#27518)``
+* ``Remove support for node_selectors param in KPO (#27515)``
+* ``Remove unused backcompat method in k8s hook (#27490)``
+* ``Drop support for providing ''resource'' as dict in ''KubernetesPodOperator'' (#27197)``
+* ``Deprecate use of core get_kube_client in PodManager (#26848)``
+* ``Don't consider airflow core conf for KPO (#26849)``
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+* ``Use log.exception where more economical than log.error (#27517)``
+
+Features
+~~~~~~~~
+
+Previously, ``name`` was a required argument for KubernetesPodOperator (when also not supplying pod
+template or full pod spec). Now, if ``name`` is not supplied, ``task_id`` will be used.
+
+KubernetesPodOperator argument ``namespace`` is now optional. If it is not supplied via the KPO
+param, the pod template file, or the full pod spec, we check the Airflow connection; then, if
+running in a k8s pod, we try to infer the namespace from the container; finally, we fall back to
+the ``default`` namespace.
+
+
+* ``Add container_resources as KubernetesPodOperator templatable (#27457)``
+* ``Add deprecation warning re unset namespace in k8s hook (#27202)``
+* ``add container_name option for SparkKubernetesSensor (#26560)``
+* ``Allow xcom sidecar container image to be configurable in KPO (#26766)``
+* ``Improve task_id to pod name conversion (#27524)``
+* ``Make pod name optional in KubernetesPodOperator (#27120)``
+* ``Make namespace optional for KPO (#27116)``
+* ``Enable template rendering for env_vars field for the @task.kubernetes decorator (#27433)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix KubernetesHook fail on an attribute absence (#25787)``
+* ``Fix log message for kubernetes hooks (#26999)``
+* ``Remove extra__kubernetes__ prefix from k8s hook extras (#27021)``
+* ``KPO should use hook's get namespace method to get namespace (#27516)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+
+4.4.0
+.....
+
+Features
+~~~~~~~~
+
+* ``feat(KubernetesPodOperator): Add support of container_security_context (#25530)``
+* ``Add @task.kubernetes taskflow decorator (#25663)``
+* ``pretty print KubernetesPodOperator rendered template env_vars (#25850)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Avoid calculating all elements when one item is needed (#26377)``
+* ``Wait for xcom sidecar container to start before sidecar exec (#25055)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``Prepare to release cncf.kubernetes provider (#26588)``
+
+4.3.0
+.....
Features
~~~~~~~~
-KubernetesPodOperator now uses KubernetesHook
-`````````````````````````````````````````````
+* ``Improve taskflow type hints with ParamSpec (#25173)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix xcom_sidecar stuck problem (#24993)``
+
+4.2.0
+.....
-Previously, KubernetesPodOperator relied on core Airflow configuration (namely setting for kubernetes executor) for certain settings used in client generation. Now KubernetesPodOperator uses KubernetesHook, and the consideration of core k8s settings is officially deprecated.
+Features
+~~~~~~~~
+
+* ``Add 'airflow_kpo_in_cluster' label to KPO pods (#24658)``
+* ``Use found pod for deletion in KubernetesPodOperator (#22092)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Revert "Fix await_container_completion condition (#23883)" (#24474)``
+* ``Update providers to use functools compat for ''cached_property'' (#24582)``
+
+Misc
+~~~~
+* ``Rename 'resources' arg in Kub op to k8s_resources (#24673)``
-If you are using the Airflow configuration settings (e.g. as opposed to operator params) to configure the kubernetes client, then prior to the next major release you will need to add an Airflow connection and set your KPO tasks to use that connection.
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Only assert stuff for mypy when type checking (#24937)``
+ * ``Remove 'xcom_push' flag from providers (#24823)``
+ * ``More typing and minor refactor for kubernetes (#24719)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Use our yaml util in all providers (#24720)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+4.1.0
+.....
+
+Features
+~~~~~~~~
+
+* Previously, KubernetesPodOperator relied on core Airflow configuration (namely settings for the
+ Kubernetes executor) for certain settings used in client generation. Now KubernetesPodOperator
+ uses KubernetesHook, and the consideration of core k8s settings is officially deprecated.
+
+* If you are using the Airflow configuration settings (e.g. as opposed to operator params) to
+ configure the kubernetes client, then prior to the next major release you will need to
+ add an Airflow connection and set your KPO tasks to use that connection.
+
+* ``Use KubernetesHook to create api client in KubernetesPodOperator (#20578)``
+* ``[FEATURE] KPO use K8S hook (#22086)``
+* ``Add param docs to KubernetesHook and KubernetesPodOperator (#23955) (#24054)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Use "remote" pod when patching KPO pod as "checked" (#23676)``
+* ``Don't use the root logger in KPO _suppress function (#23835)``
+* ``Fix await_container_completion condition (#23883)``
+
+Misc
+~~~~
+
+* ``Migrate Cncf.Kubernetes example DAGs to new design #22441 (#24132)``
+* ``Clean up f-strings in logging calls (#23597)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``pydocstyle D202 added (#24221)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
4.0.2
.....
@@ -46,7 +217,6 @@ Bug Fixes
appropriate section above if needed. Do not delete the lines(!):
* ``Add YANKED to yanked releases of the cncf.kubernetes (#23378)``
-.. Review and move the new changes to one of the sections above:
* ``Fix k8s pod.execute randomly stuck indefinitely by logs consumption (#23497) (#23618)``
* ``Revert "Fix k8s pod.execute randomly stuck indefinitely by logs consumption (#23497) (#23618)" (#23656)``
@@ -127,7 +297,7 @@ Features
~~~~~~~~
* ``Add map_index label to mapped KubernetesPodOperator (#21916)``
-* ``Change KubePodOperator labels from exeuction_date to run_id (#21960)``
+* ``Change KubernetesPodOperator labels from execution_date to run_id (#21960)``
Misc
~~~~
@@ -196,7 +366,7 @@ Notes on changes KubernetesPodOperator and PodLauncher
Overview
''''''''
-Generally speaking if you did not subclass ``KubernetesPodOperator`` and you didn't use the ``PodLauncher`` class directly,
+Generally speaking if you did not subclass ``KubernetesPodOperator`` and you did not use the ``PodLauncher`` class directly,
then you don't need to worry about this change. If however you have subclassed ``KubernetesPodOperator``, what
follows are some notes on the changes in this release.
@@ -392,7 +562,7 @@ Breaking changes
Features
~~~~~~~~
-* ``Add 'KubernetesPodOperat' 'pod-template-file' jinja template support (#15942)``
+* ``Add 'KubernetesPodOperator' 'pod-template-file' jinja template support (#15942)``
* ``Save pod name to xcom for KubernetesPodOperator (#15755)``
Bug Fixes
@@ -401,7 +571,7 @@ Bug Fixes
* ``Bug Fix Pod-Template Affinity Ignored due to empty Affinity K8S Object (#15787)``
* ``Bug Pod Template File Values Ignored (#16095)``
* ``Fix issue with parsing error logs in the KPO (#15638)``
-* ``Fix unsuccessful KubernetesPod final_state call when 'is_delete_operator_pod=True' (#15490)``
+* ``Fix unsuccessful KubernetesPodOperator final_state call when 'is_delete_operator_pod=True' (#15490)``
.. Below changes are excluded from the changelog. Move them to
appropriate section above if needed. Do not delete the lines(!):
@@ -423,7 +593,7 @@ Bug Fixes
~~~~~~~~~
* ``Fix timeout when using XCom with KubernetesPodOperator (#15388)``
-* ``Fix labels on the pod created by ''KubernetsPodOperator'' (#15492)``
+* ``Fix labels on the pod created by ''KubernetesPodOperator'' (#15492)``
1.1.0
.....
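The namespace fallback order described in the 5.0.0 "Features" notes above can be sketched as a small resolver. This is only an illustration under stated assumptions: `resolve_namespace` and its parameter names are hypothetical, not the provider's API, and the changelog does not specify the relative order of the KPO param, pod template file, and full pod spec, so the sketch treats the latter two as one input.

```python
from __future__ import annotations


def resolve_namespace(
    operator_namespace: str | None = None,
    pod_spec_namespace: str | None = None,
    connection_namespace: str | None = None,
    in_cluster_namespace: str | None = None,
) -> str:
    """Return the first namespace that is set, else ``"default"``."""
    candidates = (
        operator_namespace,    # explicit KPO ``namespace`` argument
        pod_spec_namespace,    # assumed: from pod template file or full pod spec
        connection_namespace,  # namespace configured on the Airflow connection
        in_cluster_namespace,  # inferred when the task itself runs inside a k8s pod
    )
    for candidate in candidates:
        if candidate:
            return candidate
    return "default"


if __name__ == "__main__":
    print(resolve_namespace(connection_namespace="team-a"))
    print(resolve_namespace())
```

Each source is consulted only when every earlier one is unset, so an explicit operator argument always wins over connection-level or in-cluster values.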
diff --git a/airflow/providers/cncf/kubernetes/__init__.py b/airflow/providers/cncf/kubernetes/__init__.py
index 0998e31143fc8..3ca5974776c6c 100644
--- a/airflow/providers/cncf/kubernetes/__init__.py
+++ b/airflow/providers/cncf/kubernetes/__init__.py
@@ -15,6 +15,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
+
import sys
if sys.version_info < (3, 7):
@@ -33,7 +35,7 @@ def _reduce_Logger(logger):
if logging.getLogger(logger.name) is not logger:
import pickle
- raise pickle.PicklingError('logger cannot be pickled')
+ raise pickle.PicklingError("logger cannot be pickled")
return logging.getLogger, (logger.name,)
def _reduce_RootLogger(logger):
diff --git a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py
index bf2b8329f8ccd..9d37da983e519 100644
--- a/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py
+++ b/airflow/providers/cncf/kubernetes/backcompat/backwards_compat_converters.py
@@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Executes task in a Kubernetes POD"""
-
-from typing import List
+from __future__ import annotations
from kubernetes.client import ApiClient, models as k8s
@@ -63,20 +62,6 @@ def convert_volume_mount(volume_mount) -> k8s.V1VolumeMount:
return _convert_kube_model_object(volume_mount, k8s.V1VolumeMount)
-def convert_resources(resources) -> k8s.V1ResourceRequirements:
- """
- Converts an airflow Resources object into a k8s.V1ResourceRequirements
-
- :param resources:
- :return: k8s.V1ResourceRequirements
- """
- if isinstance(resources, dict):
- from airflow.providers.cncf.kubernetes.backcompat.pod import Resources
-
- resources = Resources(**resources)
- return _convert_kube_model_object(resources, k8s.V1ResourceRequirements)
-
-
def convert_port(port) -> k8s.V1ContainerPort:
"""
Converts an airflow Port object into a k8s.V1ContainerPort
@@ -87,7 +72,7 @@ def convert_port(port) -> k8s.V1ContainerPort:
return _convert_kube_model_object(port, k8s.V1ContainerPort)
-def convert_env_vars(env_vars) -> List[k8s.V1EnvVar]:
+def convert_env_vars(env_vars) -> list[k8s.V1EnvVar]:
"""
Converts a dictionary into a list of env_vars
@@ -115,7 +100,7 @@ def convert_pod_runtime_info_env(pod_runtime_info_envs) -> k8s.V1EnvVar:
return _convert_kube_model_object(pod_runtime_info_envs, k8s.V1EnvVar)
-def convert_image_pull_secrets(image_pull_secrets) -> List[k8s.V1LocalObjectReference]:
+def convert_image_pull_secrets(image_pull_secrets) -> list[k8s.V1LocalObjectReference]:
"""
Converts a PodRuntimeInfoEnv into an k8s.V1EnvVar
diff --git a/airflow/providers/cncf/kubernetes/backcompat/pod.py b/airflow/providers/cncf/kubernetes/backcompat/pod.py
deleted file mode 100644
index 7f18117e18bfa..0000000000000
--- a/airflow/providers/cncf/kubernetes/backcompat/pod.py
+++ /dev/null
@@ -1,119 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Classes for interacting with Kubernetes API.
-
-This module is deprecated. Please use :mod:`kubernetes.client.models.V1ResourceRequirements`
-and :mod:`kubernetes.client.models.V1ContainerPort`.
-"""
-
-import warnings
-
-from kubernetes.client import models as k8s
-
-warnings.warn(
- (
- "This module is deprecated. Please use `kubernetes.client.models.V1ResourceRequirements`"
- " and `kubernetes.client.models.V1ContainerPort`."
- ),
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class Resources:
- """backwards compat for Resources."""
-
- __slots__ = (
- 'request_memory',
- 'request_cpu',
- 'limit_memory',
- 'limit_cpu',
- 'limit_gpu',
- 'request_ephemeral_storage',
- 'limit_ephemeral_storage',
- )
-
- """
- :param request_memory: requested memory
- :param request_cpu: requested CPU number
- :param request_ephemeral_storage: requested ephemeral storage
- :param limit_memory: limit for memory usage
- :param limit_cpu: Limit for CPU used
- :param limit_gpu: Limits for GPU used
- :param limit_ephemeral_storage: Limit for ephemeral storage
- """
-
- def __init__(
- self,
- request_memory=None,
- request_cpu=None,
- request_ephemeral_storage=None,
- limit_memory=None,
- limit_cpu=None,
- limit_gpu=None,
- limit_ephemeral_storage=None,
- ):
- self.request_memory = request_memory
- self.request_cpu = request_cpu
- self.request_ephemeral_storage = request_ephemeral_storage
- self.limit_memory = limit_memory
- self.limit_cpu = limit_cpu
- self.limit_gpu = limit_gpu
- self.limit_ephemeral_storage = limit_ephemeral_storage
-
- def to_k8s_client_obj(self):
- """
- Converts to k8s object.
-
- @rtype: object
- """
- limits_raw = {
- 'cpu': self.limit_cpu,
- 'memory': self.limit_memory,
- 'nvidia.com/gpu': self.limit_gpu,
- 'ephemeral-storage': self.limit_ephemeral_storage,
- }
- requests_raw = {
- 'cpu': self.request_cpu,
- 'memory': self.request_memory,
- 'ephemeral-storage': self.request_ephemeral_storage,
- }
-
- limits = {k: v for k, v in limits_raw.items() if v}
- requests = {k: v for k, v in requests_raw.items() if v}
- resource_req = k8s.V1ResourceRequirements(limits=limits, requests=requests)
- return resource_req
-
-
-class Port:
- """POD port"""
-
- __slots__ = ('name', 'container_port')
-
- def __init__(self, name=None, container_port=None):
- """Creates port"""
- self.name = name
- self.container_port = container_port
-
- def to_k8s_client_obj(self):
- """
- Converts to k8s object.
-
- :rtype: object
- """
- return k8s.V1ContainerPort(name=self.name, container_port=self.container_port)
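For DAGs that still hold resources in the flat keyword style of the removed `Resources` class, the conversion it performed can be reproduced with a small helper. This is a hedged sketch: `split_legacy_resources` is a hypothetical name, and it returns plain dicts so that it runs without the `kubernetes` package; the key mapping follows the deleted `to_k8s_client_obj` method above.

```python
from __future__ import annotations


def split_legacy_resources(legacy: dict[str, str | None]) -> dict[str, dict[str, str]]:
    """Split flat ``request_*`` / ``limit_*`` keys into ``requests`` and ``limits`` mappings."""
    limits_raw = {
        "cpu": legacy.get("limit_cpu"),
        "memory": legacy.get("limit_memory"),
        "nvidia.com/gpu": legacy.get("limit_gpu"),
        "ephemeral-storage": legacy.get("limit_ephemeral_storage"),
    }
    requests_raw = {
        "cpu": legacy.get("request_cpu"),
        "memory": legacy.get("request_memory"),
        "ephemeral-storage": legacy.get("request_ephemeral_storage"),
    }
    # Drop unset entries, as the removed class did.
    return {
        "limits": {k: v for k, v in limits_raw.items() if v},
        "requests": {k: v for k, v in requests_raw.items() if v},
    }


if __name__ == "__main__":
    print(split_legacy_resources({"request_memory": "64Mi", "limit_cpu": "500m"}))
```

With the `kubernetes` client installed, the resulting mapping should be usable as `k8s.V1ResourceRequirements(**mapping)`, which is the object type the 5.0.0 changelog asks callers to pass as `container_resources`.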
diff --git a/airflow/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py b/airflow/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py
deleted file mode 100644
index f08aecff33a89..0000000000000
--- a/airflow/providers/cncf/kubernetes/backcompat/pod_runtime_info_env.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-Classes for interacting with Kubernetes API.
-
-This module is deprecated. Please use :mod:`kubernetes.client.models.V1EnvVar`.
-"""
-
-import warnings
-
-import kubernetes.client.models as k8s
-
-warnings.warn(
- "This module is deprecated. Please use `kubernetes.client.models.V1EnvVar`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class PodRuntimeInfoEnv:
- """Defines Pod runtime information as environment variable."""
-
- def __init__(self, name, field_path):
- """
- Adds Kubernetes pod runtime information as environment variables such as namespace, pod IP, pod name.
- Full list of options can be found in kubernetes documentation.
-
- :param name: the name of the environment variable
- :param field_path: path to pod runtime info. Ex: metadata.namespace | status.podIP
- """
- self.name = name
- self.field_path = field_path
-
- def to_k8s_client_obj(self):
- """Converts to k8s object.
-
- :return: kubernetes.client.models.V1EnvVar
- """
- return k8s.V1EnvVar(
- name=self.name,
- value_from=k8s.V1EnvVarSource(field_ref=k8s.V1ObjectFieldSelector(field_path=self.field_path)),
- )
diff --git a/airflow/providers/cncf/kubernetes/backcompat/volume.py b/airflow/providers/cncf/kubernetes/backcompat/volume.py
deleted file mode 100644
index c51ce8a551e38..0000000000000
--- a/airflow/providers/cncf/kubernetes/backcompat/volume.py
+++ /dev/null
@@ -1,62 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""This module is deprecated. Please use :mod:`kubernetes.client.models.V1Volume`."""
-
-import warnings
-
-from kubernetes.client import models as k8s
-
-warnings.warn(
- "This module is deprecated. Please use `kubernetes.client.models.V1Volume`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class Volume:
- """Backward compatible Volume"""
-
- def __init__(self, name, configs):
- """Adds Kubernetes Volume to pod. allows pod to access features like ConfigMaps
- and Persistent Volumes
-
- :param name: the name of the volume mount
- :param configs: dictionary of any features needed for volume. We purposely keep this
- vague since there are multiple volume types with changing configs.
- """
- self.name = name
- self.configs = configs
-
- def to_k8s_client_obj(self) -> k8s.V1Volume:
- """
- Converts to k8s object.
-
- :return: Volume Mount k8s object
- """
- resp = k8s.V1Volume(name=self.name)
- for k, v in self.configs.items():
- snake_key = Volume._convert_to_snake_case(k)
- if hasattr(resp, snake_key):
- setattr(resp, snake_key, v)
- else:
- raise AttributeError(f"V1Volume does not have attribute {k}")
- return resp
-
- # source: https://www.geeksforgeeks.org/python-program-to-convert-camel-case-string-to-snake-case/
- @staticmethod
- def _convert_to_snake_case(input_string):
- return ''.join('_' + i.lower() if i.isupper() else i for i in input_string).lstrip('_')
diff --git a/airflow/providers/cncf/kubernetes/backcompat/volume_mount.py b/airflow/providers/cncf/kubernetes/backcompat/volume_mount.py
deleted file mode 100644
index f9faed9d04a97..0000000000000
--- a/airflow/providers/cncf/kubernetes/backcompat/volume_mount.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Classes for interacting with Kubernetes API"""
-
-import warnings
-
-from kubernetes.client import models as k8s
-
-warnings.warn(
- "This module is deprecated. Please use `kubernetes.client.models.V1VolumeMount`.",
- DeprecationWarning,
- stacklevel=2,
-)
-
-
-class VolumeMount:
- """Backward compatible VolumeMount"""
-
- __slots__ = ('name', 'mount_path', 'sub_path', 'read_only')
-
- def __init__(self, name, mount_path, sub_path, read_only):
- """
- Initialize a Kubernetes Volume Mount. Used to mount pod level volumes to
- running container.
-
- :param name: the name of the volume mount
- :param mount_path:
- :param sub_path: subpath within the volume mount
- :param read_only: whether to access pod with read-only mode
- """
- self.name = name
- self.mount_path = mount_path
- self.sub_path = sub_path
- self.read_only = read_only
-
- def to_k8s_client_obj(self) -> k8s.V1VolumeMount:
- """
- Converts to k8s object.
-
- :return: Volume Mount k8s object
- """
- return k8s.V1VolumeMount(
- name=self.name, mount_path=self.mount_path, sub_path=self.sub_path, read_only=self.read_only
- )
diff --git a/airflow/mypy/__init__.py b/airflow/providers/cncf/kubernetes/decorators/__init__.py
similarity index 100%
rename from airflow/mypy/__init__.py
rename to airflow/providers/cncf/kubernetes/decorators/__init__.py
diff --git a/airflow/providers/cncf/kubernetes/decorators/kubernetes.py b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
new file mode 100644
index 0000000000000..f68927c676433
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/decorators/kubernetes.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+import os
+import pickle
+import uuid
+from tempfile import TemporaryDirectory
+from textwrap import dedent
+from typing import TYPE_CHECKING, Callable, Sequence
+
+from kubernetes.client import models as k8s
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
+from airflow.providers.cncf.kubernetes.operators.kubernetes_pod import KubernetesPodOperator
+from airflow.providers.cncf.kubernetes.python_kubernetes_script import (
+ remove_task_decorator,
+ write_python_script,
+)
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+_PYTHON_SCRIPT_ENV = "__PYTHON_SCRIPT"
+
+_FILENAME_IN_CONTAINER = "/tmp/script.py"
+
+
+def _generate_decode_command() -> str:
+ return (
+ f'python -c "import base64, os;'
+ rf"x = os.environ[\"{_PYTHON_SCRIPT_ENV}\"];"
+ rf'f = open(\"{_FILENAME_IN_CONTAINER}\", \"w\"); f.write(x); f.close()"'
+ )
+
+
+def _read_file_contents(filename):
+ with open(filename) as script_file:
+ return script_file.read()
+
+
+class _KubernetesDecoratedOperator(DecoratedOperator, KubernetesPodOperator):
+ custom_operator_name = "@task.kubernetes"
+
+ # `cmds` and `arguments` are used internally by the operator
+ template_fields: Sequence[str] = tuple(
+ {"op_args", "op_kwargs", *KubernetesPodOperator.template_fields} - {"cmds", "arguments"}
+ )
+
+ # since we won't mutate the arguments, we should just do the shallow copy
+ # there are some cases we can't deepcopy the objects (e.g protobuf).
+ shallow_copy_attrs: Sequence[str] = ("python_callable",)
+
+ def __init__(self, namespace: str = "default", **kwargs) -> None:
+ self.pickling_library = pickle
+ super().__init__(
+ namespace=namespace,
+ name=kwargs.pop("name", f"k8s_airflow_pod_{uuid.uuid4().hex}"),
+ cmds=["bash"],
+ arguments=["-cx", f"{_generate_decode_command()} && python {_FILENAME_IN_CONTAINER}"],
+ **kwargs,
+ )
+
+ def _get_python_source(self):
+ raw_source = inspect.getsource(self.python_callable)
+ res = dedent(raw_source)
+ res = remove_task_decorator(res, "@task.kubernetes")
+ return res
+
+ def execute(self, context: Context):
+ with TemporaryDirectory(prefix="venv") as tmp_dir:
+ script_filename = os.path.join(tmp_dir, "script.py")
+ py_source = self._get_python_source()
+
+ jinja_context = {
+ "op_args": self.op_args,
+ "op_kwargs": self.op_kwargs,
+ "pickling_library": self.pickling_library.__name__,
+ "python_callable": self.python_callable.__name__,
+ "python_callable_source": py_source,
+ "string_args_global": False,
+ }
+ write_python_script(jinja_context=jinja_context, filename=script_filename)
+
+ self.env_vars = [
+ *self.env_vars,
+ k8s.V1EnvVar(name=_PYTHON_SCRIPT_ENV, value=_read_file_contents(script_filename)),
+ ]
+ return super().execute(context)
+
+
+def kubernetes_task(
+ python_callable: Callable | None = None,
+ multiple_outputs: bool | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """Kubernetes operator decorator.
+
+ This wraps a function to be executed in K8s using KubernetesPodOperator.
+ Also accepts any argument that KubernetesPodOperator will via ``kwargs``. Can be
+ reused in a single DAG.
+
+ :param python_callable: Function to decorate
+ :param multiple_outputs: if set, function return value will be
+ unrolled to multiple XCom values. Dict will unroll to xcom values with
+ keys as XCom keys. Defaults to False.
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ multiple_outputs=multiple_outputs,
+ decorated_operator_class=_KubernetesDecoratedOperator,
+ **kwargs,
+ )
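Note on the decode step above: the rendered script travels to the pod in the ``__PYTHON_SCRIPT`` environment variable, and the container's first command writes it to a file before running it. The following is a minimal standalone sketch of that round trip, not the provider's code. It rebuilds the command string with the target path as a parameter (the diff hard-codes ``/tmp/script.py``), drops the unused ``base64`` import, and assumes POSIX-style paths. The helper names ``generate_decode_command`` and ``run_decode_step`` are hypothetical.

```python
import os
import shlex
import tempfile

SCRIPT_ENV = "__PYTHON_SCRIPT"  # same value as _PYTHON_SCRIPT_ENV in the diff


def generate_decode_command(filename: str) -> str:
    """Rebuild the shell command from the diff, with the target path as a parameter."""
    return (
        'python -c "import os;'
        rf"x = os.environ[\"{SCRIPT_ENV}\"];"
        rf'f = open(\"{filename}\", \"w\"); f.write(x); f.close()"'
    )


def run_decode_step(script_source: str) -> str:
    """Simulate the container side without starting a pod or a shell."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = os.path.join(tmp_dir, "script.py")
        # shlex.split applies POSIX quoting rules, so argv is
        # ["python", "-c", "<snippet with plain double quotes>"]
        argv = shlex.split(generate_decode_command(target))
        os.environ[SCRIPT_ENV] = script_source
        try:
            # Run the snippet in-process instead of spawning `python -c`
            exec(argv[2], {})
        finally:
            del os.environ[SCRIPT_ENV]
        with open(target) as handle:
            return handle.read()


if __name__ == "__main__":
    print(run_decode_step("print('hello from the pod')\n"))
```

Because the script is passed as an environment variable value, its size is bounded by whatever limits the cluster and container runtime place on environment variables; that limit is not stated in the diff.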
diff --git a/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py b/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py
deleted file mode 100644
index d01d4b1328c68..0000000000000
--- a/airflow/providers/cncf/kubernetes/example_dags/example_spark_kubernetes.py
+++ /dev/null
@@ -1,66 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This is an example DAG which uses SparkKubernetesOperator and SparkKubernetesSensor.
-In this example, we create two tasks which execute sequentially.
-The first task is to submit sparkApplication on Kubernetes cluster(the example uses spark-pi application).
-and the second task is to check the final state of the sparkApplication that submitted in the first state.
-
-Spark-on-k8s operator is required to be already installed on Kubernetes
-https://github.com/GoogleCloudPlatform/spark-on-k8s-operator
-"""
-
-from datetime import datetime, timedelta
-
-# [START import_module]
-# The DAG object; we'll need this to instantiate a DAG
-from airflow import DAG
-
-# Operators; we need this to operate!
-from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
-from airflow.providers.cncf.kubernetes.sensors.spark_kubernetes import SparkKubernetesSensor
-
-# [END import_module]
-
-
-# [START instantiate_dag]
-
-dag = DAG(
- 'spark_pi',
- default_args={'max_active_runs': 1},
- description='submit spark-pi as sparkApplication on kubernetes',
- schedule_interval=timedelta(days=1),
- start_date=datetime(2021, 1, 1),
- catchup=False,
-)
-
-t1 = SparkKubernetesOperator(
- task_id='spark_pi_submit',
- namespace="default",
- application_file="example_spark_kubernetes_spark_pi.yaml",
- do_xcom_push=True,
- dag=dag,
-)
-
-t2 = SparkKubernetesSensor(
- task_id='spark_pi_monitor',
- namespace="default",
- application_name="{{ task_instance.xcom_pull(task_ids='spark_pi_submit')['metadata']['name'] }}",
- dag=dag,
-)
-t1 >> t2
diff --git a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
index e15dce67ef40a..85b76d5f52970 100644
--- a/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -14,29 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
+from __future__ import annotations
+
import tempfile
import warnings
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
-
-from kubernetes.config import ConfigException
-
-from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
+from typing import TYPE_CHECKING, Any, Generator
from kubernetes import client, config, watch
+from kubernetes.config import ConfigException
-try:
- import airflow.utils.yaml as yaml
-except ImportError:
- import yaml # type: ignore[no-redef]
-
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
+from airflow.kubernetes.kube_client import _disable_verify_ssl, _enable_tcp_keepalive
+from airflow.utils import yaml
def _load_body_to_dict(body):
@@ -51,10 +42,10 @@ class KubernetesHook(BaseHook):
"""
Creates Kubernetes API connection.
- - use in cluster configuration by using ``extra__kubernetes__in_cluster`` in connection
- - use custom config by providing path to the file using ``extra__kubernetes__kube_config_path``
+ - use in cluster configuration by using extra field ``in_cluster`` in connection
+ - use custom config by providing path to the file using extra field ``kube_config_path`` in connection
- use custom configuration by providing content of kubeconfig file via
- ``extra__kubernetes__kube_config`` in connection
+ extra field ``kube_config`` in connection
- use default config by providing no extras
This hook check for configuration option in the above order. Once an option is present it will
@@ -76,53 +67,52 @@ class KubernetesHook(BaseHook):
:param disable_tcp_keepalive: Set to ``True`` if you want to disable keepalive logic.
"""
- conn_name_attr = 'kubernetes_conn_id'
- default_conn_name = 'kubernetes_default'
- conn_type = 'kubernetes'
- hook_name = 'Kubernetes Cluster Connection'
+ conn_name_attr = "kubernetes_conn_id"
+ default_conn_name = "kubernetes_default"
+ conn_type = "kubernetes"
+ hook_name = "Kubernetes Cluster Connection"
+
+ DEFAULT_NAMESPACE = "default"
@staticmethod
- def get_connection_form_widgets() -> Dict[str, Any]:
+ def get_connection_form_widgets() -> dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import BooleanField, StringField
return {
- "extra__kubernetes__in_cluster": BooleanField(lazy_gettext('In cluster configuration')),
- "extra__kubernetes__kube_config_path": StringField(
- lazy_gettext('Kube config path'), widget=BS3TextFieldWidget()
- ),
- "extra__kubernetes__kube_config": StringField(
- lazy_gettext('Kube config (JSON format)'), widget=BS3TextFieldWidget()
+ "in_cluster": BooleanField(lazy_gettext("In cluster configuration")),
+ "kube_config_path": StringField(lazy_gettext("Kube config path"), widget=BS3TextFieldWidget()),
+ "kube_config": StringField(
+ lazy_gettext("Kube config (JSON format)"), widget=BS3TextFieldWidget()
),
- "extra__kubernetes__namespace": StringField(
- lazy_gettext('Namespace'), widget=BS3TextFieldWidget()
+ "namespace": StringField(lazy_gettext("Namespace"), widget=BS3TextFieldWidget()),
+ "cluster_context": StringField(lazy_gettext("Cluster context"), widget=BS3TextFieldWidget()),
+ "disable_verify_ssl": BooleanField(lazy_gettext("Disable SSL")),
+ "disable_tcp_keepalive": BooleanField(lazy_gettext("Disable TCP keepalive")),
+ "xcom_sidecar_container_image": StringField(
+ lazy_gettext("XCom sidecar image"), widget=BS3TextFieldWidget()
),
- "extra__kubernetes__cluster_context": StringField(
- lazy_gettext('Cluster context'), widget=BS3TextFieldWidget()
- ),
- "extra__kubernetes__disable_verify_ssl": BooleanField(lazy_gettext('Disable SSL')),
- "extra__kubernetes__disable_tcp_keepalive": BooleanField(lazy_gettext('Disable TCP keepalive')),
}
@staticmethod
- def get_ui_field_behaviour() -> Dict[str, Any]:
+ def get_ui_field_behaviour() -> dict[str, Any]:
"""Returns custom field behaviour"""
return {
- "hidden_fields": ['host', 'schema', 'login', 'password', 'port', 'extra'],
+ "hidden_fields": ["host", "schema", "login", "password", "port", "extra"],
"relabeling": {},
}
def __init__(
self,
- conn_id: Optional[str] = default_conn_name,
- client_configuration: Optional[client.Configuration] = None,
- cluster_context: Optional[str] = None,
- config_file: Optional[str] = None,
- in_cluster: Optional[bool] = None,
- disable_verify_ssl: Optional[bool] = None,
- disable_tcp_keepalive: Optional[bool] = None,
+ conn_id: str | None = default_conn_name,
+ client_configuration: client.Configuration | None = None,
+ cluster_context: str | None = None,
+ config_file: str | None = None,
+ in_cluster: bool | None = None,
+ disable_verify_ssl: bool | None = None,
+ disable_tcp_keepalive: bool | None = None,
) -> None:
super().__init__()
self.conn_id = conn_id
@@ -132,14 +122,7 @@ def __init__(
self.in_cluster = in_cluster
self.disable_verify_ssl = disable_verify_ssl
self.disable_tcp_keepalive = disable_tcp_keepalive
-
- # these params used for transition in KPO to K8s hook
- # for a deprecation period we will continue to consider k8s settings from airflow.cfg
- self._deprecated_core_disable_tcp_keepalive: Optional[bool] = None
- self._deprecated_core_disable_verify_ssl: Optional[bool] = None
- self._deprecated_core_in_cluster: Optional[bool] = None
- self._deprecated_core_cluster_context: Optional[str] = None
- self._deprecated_core_config_file: Optional[str] = None
+ self._is_in_cluster: bool | None = None
@staticmethod
def _coalesce_param(*params):
@@ -157,7 +140,12 @@ def conn_extras(self):
return extras
def _get_field(self, field_name):
- if field_name.startswith('extra_'):
+ """
+ Prior to Airflow 2.3, in order to make use of UI customizations for extra fields,
+ we needed to store them with the prefix ``extra__kubernetes__``. This method
+ handles the backcompat, i.e. if the extra dict contains prefixed fields.
+ """
+ if field_name.startswith("extra__"):
raise ValueError(
f"Got prefixed name {field_name}; please remove the 'extra__kubernetes__' prefix "
f"when using this method."
@@ -167,31 +155,12 @@ def _get_field(self, field_name):
prefixed_name = f"extra__kubernetes__{field_name}"
return self.conn_extras.get(prefixed_name) or None
- @staticmethod
- def _deprecation_warning_core_param(deprecation_warnings):
- settings_list_str = ''.join([f"\n\t{k}={v!r}" for k, v in deprecation_warnings])
- warnings.warn(
- f"\nApplying core Airflow settings from section [kubernetes] with the following keys:"
- f"{settings_list_str}\n"
- "In a future release, KubernetesPodOperator will no longer consider core\n"
- "Airflow settings; define an Airflow connection instead.",
- DeprecationWarning,
- )
-
- def get_conn(self) -> Any:
+ def get_conn(self) -> client.ApiClient:
"""Returns kubernetes api session for use with requests"""
-
- in_cluster = self._coalesce_param(
- self.in_cluster, self.conn_extras.get("extra__kubernetes__in_cluster") or None
- )
- cluster_context = self._coalesce_param(
- self.cluster_context, self.conn_extras.get("extra__kubernetes__cluster_context") or None
- )
- kubeconfig_path = self._coalesce_param(
- self.config_file, self.conn_extras.get("extra__kubernetes__kube_config_path") or None
- )
-
- kubeconfig = self.conn_extras.get("extra__kubernetes__kube_config") or None
+ in_cluster = self._coalesce_param(self.in_cluster, self._get_field("in_cluster"))
+ cluster_context = self._coalesce_param(self.cluster_context, self._get_field("cluster_context"))
+ kubeconfig_path = self._coalesce_param(self.config_file, self._get_field("kube_config_path"))
+ kubeconfig = self._get_field("kube_config")
num_selected_configuration = len([o for o in [in_cluster, kubeconfig, kubeconfig_path] if o])
if num_selected_configuration > 1:
@@ -208,30 +177,6 @@ def get_conn(self) -> Any:
self.disable_tcp_keepalive, _get_bool(self._get_field("disable_tcp_keepalive"))
)
- # BEGIN apply settings from core kubernetes configuration
- # this section should be removed in next major release
- deprecation_warnings: List[Tuple[str, Any]] = []
- if disable_verify_ssl is None and self._deprecated_core_disable_verify_ssl is True:
- deprecation_warnings.append(('verify_ssl', False))
- disable_verify_ssl = self._deprecated_core_disable_verify_ssl
- # by default, hook will try in_cluster first. so we only need to
- # apply core airflow config and alert when False and in_cluster not otherwise set.
- if in_cluster is None and self._deprecated_core_in_cluster is False:
- deprecation_warnings.append(('in_cluster', self._deprecated_core_in_cluster))
- in_cluster = self._deprecated_core_in_cluster
- if not cluster_context and self._deprecated_core_cluster_context:
- deprecation_warnings.append(('cluster_context', self._deprecated_core_cluster_context))
- cluster_context = self._deprecated_core_cluster_context
- if not kubeconfig_path and self._deprecated_core_config_file:
- deprecation_warnings.append(('config_file', self._deprecated_core_config_file))
- kubeconfig_path = self._deprecated_core_config_file
- if disable_tcp_keepalive is None and self._deprecated_core_disable_tcp_keepalive is True:
- deprecation_warnings.append(('enable_tcp_keepalive', False))
- disable_tcp_keepalive = True
- if deprecation_warnings:
- self._deprecation_warning_core_param(deprecation_warnings)
- # END apply settings from core kubernetes configuration
-
if disable_verify_ssl is True:
_disable_verify_ssl()
if disable_tcp_keepalive is not True:
@@ -239,11 +184,13 @@ def get_conn(self) -> Any:
if in_cluster:
self.log.debug("loading kube_config from: in_cluster configuration")
+ self._is_in_cluster = True
config.load_incluster_config()
return client.ApiClient()
if kubeconfig_path is not None:
self.log.debug("loading kube_config from: %s", kubeconfig_path)
+ self._is_in_cluster = False
config.load_kube_config(
config_file=kubeconfig_path,
client_configuration=self.client_configuration,
@@ -256,6 +203,7 @@ def get_conn(self) -> Any:
self.log.debug("loading kube_config from: connection kube_config")
temp_config.write(kubeconfig.encode())
temp_config.flush()
+ self._is_in_cluster = False
config.load_kube_config(
config_file=temp_config.name,
client_configuration=self.client_configuration,
@@ -265,36 +213,47 @@ def get_conn(self) -> Any:
return self._get_default_client(cluster_context=cluster_context)
- def _get_default_client(self, *, cluster_context=None):
+ def _get_default_client(self, *, cluster_context: str | None = None) -> client.ApiClient:
# if we get here, then no configuration has been supplied
# we should try in_cluster since that's most likely
# but failing that just load assuming a kubeconfig file
# in the default location
try:
config.load_incluster_config(client_configuration=self.client_configuration)
+ self._is_in_cluster = True
except ConfigException:
self.log.debug("loading kube_config from: default file")
+ self._is_in_cluster = False
config.load_kube_config(
client_configuration=self.client_configuration,
context=cluster_context,
)
return client.ApiClient()
+ @property
+ def is_in_cluster(self) -> bool:
+ """Expose whether the hook is configured with ``load_incluster_config`` or not"""
+ if self._is_in_cluster is not None:
+ return self._is_in_cluster
+ self.api_client # so we can determine if we are in_cluster or not
+ if TYPE_CHECKING:
+ assert self._is_in_cluster is not None
+ return self._is_in_cluster
+
@cached_property
- def api_client(self) -> Any:
+ def api_client(self) -> client.ApiClient:
"""Cached Kubernetes API client"""
return self.get_conn()
@cached_property
- def core_v1_client(self):
+ def core_v1_client(self) -> client.CoreV1Api:
return client.CoreV1Api(api_client=self.api_client)
def create_custom_object(
- self, group: str, version: str, plural: str, body: Union[str, dict], namespace: Optional[str] = None
+ self, group: str, version: str, plural: str, body: str | dict, namespace: str | None = None
):
"""
Creates custom resource definition object in Kubernetes
-
:param group: api group
:param version: api version
:param plural: api plural
@@ -302,35 +261,40 @@ def create_custom_object(
:param namespace: kubernetes namespace
"""
api = client.CustomObjectsApi(self.api_client)
- if namespace is None:
- namespace = self.get_namespace()
+ namespace = namespace or self._get_namespace() or self.DEFAULT_NAMESPACE
+
if isinstance(body, str):
body_dict = _load_body_to_dict(body)
else:
body_dict = body
- try:
- api.delete_namespaced_custom_object(
- group=group,
- version=version,
- namespace=namespace,
- plural=plural,
- name=body_dict["metadata"]["name"],
- )
- self.log.warning("Deleted SparkApplication with the same name.")
- except client.rest.ApiException:
- self.log.info("SparkApp %s not found.", body_dict['metadata']['name'])
+
+ # Attribute "name" is not mandatory if "generateName" is used instead
+ if "name" in body_dict["metadata"]:
+ try:
+ api.delete_namespaced_custom_object(
+ group=group,
+ version=version,
+ namespace=namespace,
+ plural=plural,
+ name=body_dict["metadata"]["name"],
+ )
+
+ self.log.warning("Deleted SparkApplication with the same name")
+ except client.rest.ApiException:
+ self.log.info("SparkApplication %s not found", body_dict["metadata"]["name"])
try:
response = api.create_namespaced_custom_object(
group=group, version=version, namespace=namespace, plural=plural, body=body_dict
)
+
self.log.debug("Response: %s", response)
return response
except client.rest.ApiException as e:
raise AirflowException(f"Exception when calling -> create_custom_object: {e}\n")
def get_custom_object(
- self, group: str, version: str, plural: str, name: str, namespace: Optional[str] = None
+ self, group: str, version: str, plural: str, name: str, namespace: str | None = None
):
"""
Get custom resource definition object from Kubernetes
@@ -342,8 +306,7 @@ def get_custom_object(
:param namespace: kubernetes namespace
"""
api = client.CustomObjectsApi(self.api_client)
- if namespace is None:
- namespace = self.get_namespace()
+ namespace = namespace or self._get_namespace() or self.DEFAULT_NAMESPACE
try:
response = api.get_namespaced_custom_object(
group=group, version=version, namespace=namespace, plural=plural, name=name
@@ -352,21 +315,45 @@ def get_custom_object(
except client.rest.ApiException as e:
raise AirflowException(f"Exception when calling -> get_custom_object: {e}\n")
- def get_namespace(self) -> Optional[str]:
- """Returns the namespace that defined in the connection"""
+ def get_namespace(self) -> str | None:
+ """
+ Returns the namespace defined in the connection or 'default'.
+
+ TODO: in provider version 6.0, return None when namespace not defined in connection
+ """
+ namespace = self._get_namespace()
+ if self.conn_id and not namespace:
+ warnings.warn(
+ "Airflow connection defined but namespace is not set; returning 'default'. In "
+ "cncf.kubernetes provider version 6.0 we will return None when namespace is "
+ "not defined in the connection so that it's clear whether user intends 'default' or "
+ "whether namespace is unset (which is required in order to apply precedence logic in "
+ "KubernetesPodOperator).",
+ DeprecationWarning,
+ )
+ return "default"
+ return namespace
+
+ def _get_namespace(self) -> str | None:
+ """
+ Returns the namespace defined in the connection, or None if it is not set.
+
+ TODO: in provider version 6.0, get rid of this method and make it the behavior of get_namespace.
+ """
if self.conn_id:
- connection = self.get_connection(self.conn_id)
- extras = connection.extra_dejson
- namespace = extras.get("extra__kubernetes__namespace", "default")
- return namespace
+ return self._get_field("namespace")
return None
+ def get_xcom_sidecar_container_image(self):
+ """Returns the xcom sidecar image that is defined in the connection"""
+ return self._get_field("xcom_sidecar_container_image")
+
def get_pod_log_stream(
self,
pod_name: str,
- container: Optional[str] = "",
- namespace: Optional[str] = None,
- ) -> Tuple[watch.Watch, Generator[str, None, None]]:
+ container: str | None = "",
+ namespace: str | None = None,
+ ) -> tuple[watch.Watch, Generator[str, None, None]]:
"""
Retrieves a log stream for a container in a kubernetes pod.
@@ -374,23 +361,22 @@ def get_pod_log_stream(
:param container: container name
:param namespace: kubernetes namespace
"""
- api = client.CoreV1Api(self.api_client)
watcher = watch.Watch()
return (
watcher,
watcher.stream(
- api.read_namespaced_pod_log,
+ self.core_v1_client.read_namespaced_pod_log,
name=pod_name,
container=container,
- namespace=namespace if namespace else self.get_namespace(),
+ namespace=namespace or self._get_namespace() or self.DEFAULT_NAMESPACE,
),
)
def get_pod_logs(
self,
pod_name: str,
- container: Optional[str] = "",
- namespace: Optional[str] = None,
+ container: str | None = "",
+ namespace: str | None = None,
):
"""
Retrieves a container's log from the specified pod.
@@ -399,16 +385,15 @@ def get_pod_logs(
:param container: container name
:param namespace: kubernetes namespace
"""
- api = client.CoreV1Api(self.api_client)
- return api.read_namespaced_pod_log(
+ return self.core_v1_client.read_namespaced_pod_log(
name=pod_name,
container=container,
_preload_content=False,
- namespace=namespace if namespace else self.get_namespace(),
+ namespace=namespace or self._get_namespace() or self.DEFAULT_NAMESPACE,
)
-def _get_bool(val) -> Optional[bool]:
+def _get_bool(val) -> bool | None:
"""
Converts val to bool if can be done with certainty.
If we cannot infer intention we return None.
@@ -416,8 +401,8 @@ def _get_bool(val) -> Optional[bool]:
if isinstance(val, bool):
return val
elif isinstance(val, str):
- if val.strip().lower() == 'true':
+ if val.strip().lower() == "true":
return True
- elif val.strip().lower() == 'false':
+ elif val.strip().lower() == "false":
return False
return None
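Note on the namespace handling in this hook: several methods now resolve the namespace as ``namespace or self._get_namespace() or self.DEFAULT_NAMESPACE``, and ``_get_field`` accepts both plain and ``extra__kubernetes__``-prefixed extras. The sketch below models that precedence with plain dictionaries, as an illustration only. The functions ``get_field`` and ``resolve_namespace`` are hypothetical stand-ins, and the "unprefixed key wins over prefixed key" order is an assumption, since the middle of ``_get_field`` is not visible in this hunk.

```python
from __future__ import annotations

DEFAULT_NAMESPACE = "default"
LEGACY_PREFIX = "extra__kubernetes__"


def get_field(extras: dict, field_name: str):
    """Look up a connection extra, accepting the legacy prefixed form as a fallback."""
    if field_name.startswith("extra__"):
        raise ValueError(
            f"Got prefixed name {field_name}; please remove the "
            f"'{LEGACY_PREFIX}' prefix when using this method."
        )
    # Assumption: the unprefixed key takes priority when both are present
    if field_name in extras:
        return extras[field_name] or None
    return extras.get(f"{LEGACY_PREFIX}{field_name}") or None


def resolve_namespace(explicit: str | None, extras: dict) -> str:
    """Explicit argument first, then the connection extra, then 'default'."""
    return explicit or get_field(extras, "namespace") or DEFAULT_NAMESPACE


if __name__ == "__main__":
    print(resolve_namespace(None, {"extra__kubernetes__namespace": "legacy-ns"}))
    print(resolve_namespace("team-a", {"namespace": "conn-ns"}))
    print(resolve_namespace(None, {}))
```

Empty strings are treated as unset because of the ``or None`` coercion, which is what lets the final fallback to ``"default"`` apply.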
diff --git a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
index 69f120e823b3b..c1735158ec1ec 100644
--- a/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
+++ b/airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -15,17 +15,18 @@
# specific language governing permissions and limitations
# under the License.
"""Executes task in a Kubernetes POD"""
+from __future__ import annotations
+
import json
import logging
import re
-import sys
import warnings
from contextlib import AbstractContextManager
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
from kubernetes.client import CoreV1Api, models as k8s
-from airflow.configuration import conf
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.kubernetes import pod_generator
from airflow.kubernetes.pod_generator import PodGenerator
@@ -38,7 +39,6 @@
convert_image_pull_secrets,
convert_pod_runtime_info_env,
convert_port,
- convert_resources,
convert_toleration,
convert_volume,
convert_volume_mount,
@@ -56,17 +56,37 @@
from airflow.utils.helpers import prune_dict, validate_key
from airflow.version import version as airflow_version
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
if TYPE_CHECKING:
import jinja2
from airflow.utils.context import Context
+def _task_id_to_pod_name(val: str) -> str:
+ """
+ Given a task_id, convert it to a pod name.
+ Adds a 0 if start or end char is invalid.
+ Replaces any other invalid char with `-`.
+
+ :param val: non-empty string, presumed to be a task id
+ :return: valid kubernetes object name.
+ """
+ if not val:
+ raise ValueError("_task_id_to_pod_name requires non-empty string.")
+ val = val.lower()
+ if not re.match(r"[a-z0-9]", val[0]):
+ val = f"0{val}"
+ if not re.match(r"[a-z0-9]", val[-1]):
+ val = f"{val}0"
+ val = re.sub(r"[^a-z0-9\-.]", "-", val)
+ if len(val) > 253:
+ raise ValueError(
+ f"Pod name {val} is longer than 253 characters. "
+ "See https://kubernetes.io/docs/concepts/overview/working-with-objects/names/."
+ )
+ return val
+
+
class PodReattachFailure(AirflowException):
"""When we expect to be able to find a pod but cannot."""
@@ -102,6 +122,7 @@ class KubernetesPodOperator(BaseOperator):
:param volume_mounts: volumeMounts for the launched pod.
:param volumes: volumes for the launched pod. Includes ConfigMaps and PersistentVolumes.
:param env_vars: Environment variables initialized in the container. (templated)
+ :param env_from: (Optional) List of sources to populate environment variables in the container.
:param secrets: Kubernetes secrets to inject in the container.
They can be exposed as environment vars or files in a volume.
:param in_cluster: run kubernetes client with in_cluster configuration.
@@ -116,7 +137,7 @@ class KubernetesPodOperator(BaseOperator):
:param annotations: non-identifying metadata you can attach to the Pod.
Can be a large range of data, and can include characters
that are not permitted by labels.
- :param resources: resources for the launched pod.
+ :param container_resources: resources for the launched pod. (templated)
:param affinity: affinity scheduling rules for the launched pod.
:param config_file: The path to the Kubernetes config file. (templated)
If not specified, default value is ``~/.kube/config``
@@ -131,6 +152,7 @@ class KubernetesPodOperator(BaseOperator):
:param hostnetwork: If True enable host networking on the pod.
:param tolerations: A list of kubernetes tolerations.
:param security_context: security options the pod should run with (PodSecurityContext).
+ :param container_security_context: security options the container should run with.
:param dnspolicy: dnspolicy for the pod.
:param schedulername: Specify a schedulername for the pod
:param full_pod_spec: The complete podSpec
@@ -141,76 +163,95 @@ class KubernetesPodOperator(BaseOperator):
XCom when the container completes.
:param pod_template_file: path to pod template file (templated)
:param priority_class_name: priority class name for the launched Pod
+ :param pod_runtime_info_envs: (Optional) A list of environment variables,
+ to be set in the container.
:param termination_grace_period: Termination grace period if task killed in UI,
defaults to kubernetes default
- :param: kubernetes_conn_id: To retrieve credentials for your k8s cluster from an Airflow connection
+ :param configmaps: (Optional) A list of ConfigMap names from which to populate the container's
+ environment variables. The contents of each target ConfigMap's Data field are exposed
+ as key-value pairs in the environment.
+ Extends env_from.
"""
- BASE_CONTAINER_NAME = 'base'
- POD_CHECKED_KEY = 'already_checked'
+ BASE_CONTAINER_NAME = "base"
+ POD_CHECKED_KEY = "already_checked"
template_fields: Sequence[str] = (
- 'image',
- 'cmds',
- 'arguments',
- 'env_vars',
- 'labels',
- 'config_file',
- 'pod_template_file',
- 'namespace',
+ "image",
+ "cmds",
+ "arguments",
+ "env_vars",
+ "labels",
+ "config_file",
+ "pod_template_file",
+ "namespace",
+ "container_resources",
)
+ template_fields_renderers = {"env_vars": "py"}
def __init__(
self,
*,
- kubernetes_conn_id: Optional[str] = None, # 'kubernetes_default',
- namespace: Optional[str] = None,
- image: Optional[str] = None,
- name: Optional[str] = None,
- random_name_suffix: Optional[bool] = True,
- cmds: Optional[List[str]] = None,
- arguments: Optional[List[str]] = None,
- ports: Optional[List[k8s.V1ContainerPort]] = None,
- volume_mounts: Optional[List[k8s.V1VolumeMount]] = None,
- volumes: Optional[List[k8s.V1Volume]] = None,
- env_vars: Optional[List[k8s.V1EnvVar]] = None,
- env_from: Optional[List[k8s.V1EnvFromSource]] = None,
- secrets: Optional[List[Secret]] = None,
- in_cluster: Optional[bool] = None,
- cluster_context: Optional[str] = None,
- labels: Optional[Dict] = None,
+ kubernetes_conn_id: str | None = None, # 'kubernetes_default',
+ namespace: str | None = None,
+ image: str | None = None,
+ name: str | None = None,
+ random_name_suffix: bool | None = True,
+ cmds: list[str] | None = None,
+ arguments: list[str] | None = None,
+ ports: list[k8s.V1ContainerPort] | None = None,
+ volume_mounts: list[k8s.V1VolumeMount] | None = None,
+ volumes: list[k8s.V1Volume] | None = None,
+ env_vars: list[k8s.V1EnvVar] | None = None,
+ env_from: list[k8s.V1EnvFromSource] | None = None,
+ secrets: list[Secret] | None = None,
+ in_cluster: bool | None = None,
+ cluster_context: str | None = None,
+ labels: dict | None = None,
reattach_on_restart: bool = True,
startup_timeout_seconds: int = 120,
get_logs: bool = True,
- image_pull_policy: Optional[str] = None,
- annotations: Optional[Dict] = None,
- resources: Optional[k8s.V1ResourceRequirements] = None,
- affinity: Optional[k8s.V1Affinity] = None,
- config_file: Optional[str] = None,
- node_selectors: Optional[dict] = None,
- node_selector: Optional[dict] = None,
- image_pull_secrets: Optional[List[k8s.V1LocalObjectReference]] = None,
- service_account_name: Optional[str] = None,
+ image_pull_policy: str | None = None,
+ annotations: dict | None = None,
+ container_resources: k8s.V1ResourceRequirements | None = None,
+ affinity: k8s.V1Affinity | None = None,
+ config_file: str | None = None,
+ node_selector: dict | None = None,
+ image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
+ service_account_name: str | None = None,
is_delete_operator_pod: bool = True,
hostnetwork: bool = False,
- tolerations: Optional[List[k8s.V1Toleration]] = None,
- security_context: Optional[Dict] = None,
- dnspolicy: Optional[str] = None,
- schedulername: Optional[str] = None,
- full_pod_spec: Optional[k8s.V1Pod] = None,
- init_containers: Optional[List[k8s.V1Container]] = None,
+ tolerations: list[k8s.V1Toleration] | None = None,
+ security_context: dict | None = None,
+ container_security_context: dict | None = None,
+ dnspolicy: str | None = None,
+ schedulername: str | None = None,
+ full_pod_spec: k8s.V1Pod | None = None,
+ init_containers: list[k8s.V1Container] | None = None,
log_events_on_failure: bool = False,
do_xcom_push: bool = False,
- pod_template_file: Optional[str] = None,
- priority_class_name: Optional[str] = None,
- pod_runtime_info_envs: Optional[List[k8s.V1EnvVar]] = None,
- termination_grace_period: Optional[int] = None,
- configmaps: Optional[List[str]] = None,
+ pod_template_file: str | None = None,
+ priority_class_name: str | None = None,
+ pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None,
+ termination_grace_period: int | None = None,
+ configmaps: list[str] | None = None,
**kwargs,
) -> None:
- if kwargs.get('xcom_push') is not None:
- raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
- super().__init__(resources=None, **kwargs)
+ # TODO: remove in provider 6.0.0 release. This is a mitigation step to advise users to switch to the
+ # container_resources parameter.
+ if isinstance(kwargs.get("resources"), k8s.V1ResourceRequirements):
+ raise AirflowException(
+ "Specifying resources for the launched pod with 'resources' is deprecated. "
+ "Use 'container_resources' instead."
+ )
+ # TODO: remove in provider 6.0.0 release. This is a mitigation step to advise users to switch to the
+ # node_selector parameter.
+ if "node_selectors" in kwargs:
+ raise ValueError(
+ "Param `node_selectors` supplied. This param is no longer supported. "
+ "Use `node_selector` instead."
+ )
+ super().__init__(**kwargs)
self.kubernetes_conn_id = kubernetes_conn_id
self.do_xcom_push = do_xcom_push
self.image = image
@@ -234,19 +275,10 @@ def __init__(
self.reattach_on_restart = reattach_on_restart
self.get_logs = get_logs
self.image_pull_policy = image_pull_policy
- if node_selectors:
- # Node selectors is incorrect based on k8s API
- warnings.warn(
- "node_selectors is deprecated. Please use node_selector instead.", DeprecationWarning
- )
- self.node_selector = node_selectors
- elif node_selector:
- self.node_selector = node_selector
- else:
- self.node_selector = {}
+ self.node_selector = node_selector or {}
self.annotations = annotations or {}
self.affinity = convert_affinity(affinity) if affinity else {}
- self.k8s_resources = convert_resources(resources) if resources else {}
+ self.container_resources = container_resources
self.config_file = config_file
self.image_pull_secrets = convert_image_pull_secrets(image_pull_secrets) if image_pull_secrets else []
self.service_account_name = service_account_name
@@ -256,6 +288,7 @@ def __init__(
[convert_toleration(toleration) for toleration in tolerations] if tolerations else []
)
self.security_context = security_context or {}
+ self.container_security_context = container_security_context
self.dnspolicy = dnspolicy
self.schedulername = schedulername
self.full_pod_spec = full_pod_spec
@@ -266,25 +299,37 @@ def __init__(
self.name = self._set_name(name)
self.random_name_suffix = random_name_suffix
self.termination_grace_period = termination_grace_period
- self.pod_request_obj: Optional[k8s.V1Pod] = None
- self.pod: Optional[k8s.V1Pod] = None
+ self.pod_request_obj: k8s.V1Pod | None = None
+ self.pod: k8s.V1Pod | None = None
+
+ @cached_property
+ def _incluster_namespace(self):
+ from pathlib import Path
+
+ path = Path("/var/run/secrets/kubernetes.io/serviceaccount/namespace")
+ return path.exists() and path.read_text() or None
def _render_nested_template_fields(
self,
content: Any,
- context: 'Context',
- jinja_env: "jinja2.Environment",
+ context: Context,
+ jinja_env: jinja2.Environment,
seen_oids: set,
) -> None:
if id(content) not in seen_oids and isinstance(content, k8s.V1EnvVar):
seen_oids.add(id(content))
- self._do_render_template_fields(content, ('value', 'name'), context, jinja_env, seen_oids)
+ self._do_render_template_fields(content, ("value", "name"), context, jinja_env, seen_oids)
+ return
+
+ if id(content) not in seen_oids and isinstance(content, k8s.V1ResourceRequirements):
+ seen_oids.add(id(content))
+ self._do_render_template_fields(content, ("limits", "requests"), context, jinja_env, seen_oids)
return
super()._render_nested_template_fields(content, context, jinja_env, seen_oids)
@staticmethod
- def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool = True) -> dict:
+ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
"""
Generate labels for the pod to track the pod in case of Operator crash
@@ -294,26 +339,25 @@ def _get_ti_pod_labels(context: Optional[dict] = None, include_try_number: bool
if not context:
return {}
- ti = context['ti']
- run_id = context['run_id']
+ ti = context["ti"]
+ run_id = context["run_id"]
labels = {
- 'dag_id': ti.dag_id,
- 'task_id': ti.task_id,
- 'run_id': run_id,
- 'kubernetes_pod_operator': 'True',
+ "dag_id": ti.dag_id,
+ "task_id": ti.task_id,
+ "run_id": run_id,
+ "kubernetes_pod_operator": "True",
}
- # If running on Airflow 2.3+:
- map_index = getattr(ti, 'map_index', -1)
+ map_index = ti.map_index
if map_index >= 0:
- labels['map_index'] = map_index
+ labels["map_index"] = map_index
if include_try_number:
labels.update(try_number=ti.try_number)
# In the case of sub dags this is just useful
- if context['dag'].is_subdag:
- labels['parent_dag_id'] = context['dag'].parent_dag.dag_id
+ if context["dag"].parent_dag:
+ labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
# Ensure that label is valid for Kube,
# and if not truncate/remove invalid chars and replace with short hash.
for label_id, label in labels.items():
@@ -326,21 +370,24 @@ def pod_manager(self) -> PodManager:
return PodManager(kube_client=self.client)
def get_hook(self):
+ warnings.warn("get_hook is deprecated. Please use hook instead.", DeprecationWarning, stacklevel=2)
+ return self.hook
+
+ @cached_property
+ def hook(self) -> KubernetesHook:
hook = KubernetesHook(
conn_id=self.kubernetes_conn_id,
in_cluster=self.in_cluster,
config_file=self.config_file,
cluster_context=self.cluster_context,
)
- self._patch_deprecated_k8s_settings(hook)
return hook
@cached_property
def client(self) -> CoreV1Api:
- hook = self.get_hook()
- return hook.core_v1_client
+ return self.hook.core_v1_client
- def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.V1Pod]:
+ def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None:
"""Returns an already-running pod for this task instance if one exists."""
label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
pod_list = self.client.list_namespaced_pod(
@@ -351,15 +398,15 @@ def find_pod(self, namespace, context, *, exclude_checked=True) -> Optional[k8s.
pod = None
num_pods = len(pod_list)
if num_pods > 1:
- raise AirflowException(f'More than one pod running with labels {label_selector}')
+ raise AirflowException(f"More than one pod running with labels {label_selector}")
elif num_pods == 1:
pod = pod_list[0]
self.log.info("Found matching pod %s with labels %s", pod.metadata.name, pod.metadata.labels)
- self.log.info("`try_number` of task_instance: %s", context['ti'].try_number)
- self.log.info("`try_number` of pod: %s", pod.metadata.labels['try_number'])
+ self.log.info("`try_number` of task_instance: %s", context["ti"].try_number)
+ self.log.info("`try_number` of pod: %s", pod.metadata.labels["try_number"])
return pod
- def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
+ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context: Context) -> k8s.V1Pod:
if self.reattach_on_restart:
pod = self.find_pod(self.namespace or pod_request_obj.metadata.namespace, context=context)
if pod:
@@ -368,7 +415,7 @@ def get_or_create_pod(self, pod_request_obj: k8s.V1Pod, context):
self.pod_manager.create_pod(pod=pod_request_obj)
return pod_request_obj
- def await_pod_start(self, pod):
+ def await_pod_start(self, pod: k8s.V1Pod):
try:
self.pod_manager.await_pod_start(pod=pod, startup_timeout=self.startup_timeout_seconds)
except PodLaunchFailedException:
@@ -377,13 +424,17 @@ def await_pod_start(self, pod):
self.log.error("Pod Event: %s - %s", event.reason, event.message)
raise
- def extract_xcom(self, pod):
+ def extract_xcom(self, pod: k8s.V1Pod):
"""Retrieves xcom value and kills xcom sidecar container"""
result = self.pod_manager.extract_xcom(pod)
- self.log.info("xcom result: \n%s", result)
- return json.loads(result)
+ if isinstance(result, str) and result.rstrip() == "__airflow_xcom_result_empty__":
+ self.log.info("Result file is empty.")
+ return None
+ else:
+ self.log.info("xcom result: \n%s", result)
+ return json.loads(result)
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
remote_pod = None
try:
self.pod_request_obj = self.build_pod_request_obj(context)
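The new empty-result handling in ``extract_xcom`` can be sketched in isolation. This is an illustrative stand-in, not the operator's method: the sentinel string is taken from the diff, while the names ``EMPTY_XCOM_SENTINEL`` and ``parse_xcom_result`` are hypothetical.

```python
import json

# Sentinel literal as it appears in the diff; the constant name is hypothetical.
EMPTY_XCOM_SENTINEL = "__airflow_xcom_result_empty__"


def parse_xcom_result(result):
    """Return None for the empty-result sentinel, otherwise decode the JSON payload."""
    if isinstance(result, str) and result.rstrip() == EMPTY_XCOM_SENTINEL:
        return None
    return json.loads(result)


if __name__ == "__main__":
    print(parse_xcom_result('{"rows": 3}'))                      # {'rows': 3}
    print(parse_xcom_result("__airflow_xcom_result_empty__\n"))  # None
```

Before this change the sentinel would have been passed to ``json.loads`` and raised a decode error; with the check, an empty result file maps to ``None``.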
@@ -391,6 +442,8 @@ def execute(self, context: 'Context'):
pod_request_obj=self.pod_request_obj,
context=context,
)
+ # get remote pod for use in cleanup methods
+ remote_pod = self.find_pod(self.pod.metadata.namespace, context=context)
self.await_pod_start(pod=self.pod)
if self.get_logs:
@@ -405,6 +458,7 @@ def execute(self, context: 'Context'):
)
if self.do_xcom_push:
+ self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
result = self.extract_xcom(pod=self.pod)
remote_pod = self.pod_manager.await_pod_completion(self.pod)
finally:
@@ -412,14 +466,14 @@ def execute(self, context: 'Context'):
pod=self.pod or self.pod_request_obj,
remote_pod=remote_pod,
)
- ti = context['ti']
- ti.xcom_push(key='pod_name', value=self.pod.metadata.name)
- ti.xcom_push(key='pod_namespace', value=self.pod.metadata.namespace)
+ ti = context["ti"]
+ ti.xcom_push(key="pod_name", value=self.pod.metadata.name)
+ ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace)
if self.do_xcom_push:
return result
def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
- pod_phase = remote_pod.status.phase if hasattr(remote_pod, 'status') else None
+ pod_phase = remote_pod.status.phase if hasattr(remote_pod, "status") else None
if not self.is_delete_operator_pod:
with _suppress(Exception):
self.patch_already_checked(remote_pod)
@@ -429,40 +483,39 @@ def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod):
for event in self.pod_manager.read_pod_events(pod).items:
self.log.error("Pod Event: %s - %s", event.reason, event.message)
with _suppress(Exception):
- self.process_pod_deletion(pod)
+ self.process_pod_deletion(remote_pod)
error_message = get_container_termination_message(remote_pod, self.BASE_CONTAINER_NAME)
error_message = "\n" + error_message if error_message else ""
raise AirflowException(
- f'Pod {pod and pod.metadata.name} returned a failure:{error_message}\n{remote_pod}'
+ f"Pod {pod and pod.metadata.name} returned a failure:{error_message}\n{remote_pod}"
)
else:
with _suppress(Exception):
- self.process_pod_deletion(pod)
+ self.process_pod_deletion(remote_pod)
- def process_pod_deletion(self, pod):
- if self.is_delete_operator_pod:
- self.log.info("Deleting pod: %s", pod.metadata.name)
- self.pod_manager.delete_pod(pod)
- else:
- self.log.info("skipping deleting pod: %s", pod.metadata.name)
+ def process_pod_deletion(self, pod: k8s.V1Pod):
+ if pod is not None:
+ if self.is_delete_operator_pod:
+ self.log.info("Deleting pod: %s", pod.metadata.name)
+ self.pod_manager.delete_pod(pod)
+ else:
+ self.log.info("skipping deleting pod: %s", pod.metadata.name)
- def _build_find_pod_label_selector(self, context: Optional[dict] = None, *, exclude_checked=True) -> str:
+ def _build_find_pod_label_selector(self, context: Context | None = None, *, exclude_checked=True) -> str:
labels = self._get_ti_pod_labels(context, include_try_number=False)
- label_strings = [f'{label_id}={label}' for label_id, label in sorted(labels.items())]
- labels_value = ','.join(label_strings)
+ label_strings = [f"{label_id}={label}" for label_id, label in sorted(labels.items())]
+ labels_value = ",".join(label_strings)
if exclude_checked:
- labels_value += f',{self.POD_CHECKED_KEY}!=True'
- labels_value += ',!airflow-worker'
+ labels_value += f",{self.POD_CHECKED_KEY}!=True"
+ labels_value += ",!airflow-worker"
return labels_value
- def _set_name(self, name):
- if name is None:
- if self.pod_template_file or self.full_pod_spec:
- return None
- raise AirflowException("`name` is required unless `pod_template_file` or `full_pod_spec` is set")
-
- validate_key(name, max_length=220)
- return re.sub(r'[^a-z0-9-]+', '-', name.lower())
+ @staticmethod
+ def _set_name(name: str | None) -> str | None:
+ if name is not None:
+ validate_key(name, max_length=220)
+ return re.sub(r"[^a-z0-9-]+", "-", name.lower())
+ return None
def patch_already_checked(self, pod: k8s.V1Pod):
"""Add an "already checked" annotation to ensure we don't reattach on retries"""
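The selector string built by ``_build_find_pod_label_selector`` can be reproduced with plain Python. The sketch below assumes the ``!airflow-worker`` clause is appended unconditionally (indentation is not visible in this diff); the function name ``build_label_selector`` and the sample labels are hypothetical.

```python
def build_label_selector(
    labels: dict,
    exclude_checked: bool = True,
    checked_key: str = "already_checked",
) -> str:
    """Join sorted labels into a Kubernetes label selector string."""
    label_strings = [f"{key}={value}" for key, value in sorted(labels.items())]
    selector = ",".join(label_strings)
    if exclude_checked:
        # Skip pods that an earlier try has already inspected.
        selector += f",{checked_key}!=True"
    # Exclude pods carrying the airflow-worker label (assumed unconditional).
    selector += ",!airflow-worker"
    return selector


if __name__ == "__main__":
    print(build_label_selector({"task_id": "t", "dag_id": "d", "run_id": "r"}))
    # dag_id=d,run_id=r,task_id=t,already_checked!=True,!airflow-worker
```

Sorting the labels keeps the selector deterministic, which makes log output and tests stable across runs.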
@@ -481,7 +534,7 @@ def on_kill(self) -> None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.client.delete_namespaced_pod(**kwargs)
- def build_pod_request_obj(self, context=None):
+ def build_pod_request_obj(self, context: Context | None = None) -> k8s.V1Pod:
"""
Returns V1Pod object based on pod template file, full pod spec, and other operator parameters.
@@ -497,7 +550,7 @@ def build_pod_request_obj(self, context=None):
elif self.full_pod_spec:
pod_template = self.full_pod_spec
else:
- pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="name"))
+ pod_template = k8s.V1Pod(metadata=k8s.V1ObjectMeta())
pod = k8s.V1Pod(
api_version="v1",
@@ -520,11 +573,12 @@ def build_pod_request_obj(self, context=None):
command=self.cmds,
ports=self.ports,
image_pull_policy=self.image_pull_policy,
- resources=self.k8s_resources,
+ resources=self.container_resources,
volume_mounts=self.volume_mounts,
args=self.arguments,
env=self.env_vars,
env_from=self.env_from,
+ security_context=self.container_security_context,
)
],
image_pull_secrets=self.image_pull_secrets,
@@ -533,7 +587,7 @@ def build_pod_request_obj(self, context=None):
security_context=self.security_context,
dns_policy=self.dnspolicy,
scheduler_name=self.schedulername,
- restart_policy='Never',
+ restart_policy="Never",
priority_class_name=self.priority_class_name,
volumes=self.volumes,
),
@@ -541,18 +595,30 @@ def build_pod_request_obj(self, context=None):
pod = PodGenerator.reconcile_pods(pod_template, pod)
+ if not pod.metadata.name:
+ pod.metadata.name = _task_id_to_pod_name(self.task_id)
+
if self.random_name_suffix:
pod.metadata.name = PodGenerator.make_unique_pod_id(pod.metadata.name)
+ if not pod.metadata.namespace:
+ # todo: replace with call to `hook.get_namespace` in 6.0, when it doesn't default to `default`.
+ # if namespace not actually defined in hook, we want to check k8s if in cluster
+ hook_namespace = self.hook._get_namespace()
+ pod_namespace = self.namespace or hook_namespace or self._incluster_namespace or "default"
+ pod.metadata.namespace = pod_namespace
+
for secret in self.secrets:
self.log.debug("Adding secret to task %s", self.task_id)
pod = secret.attach_to_pod(pod)
if self.do_xcom_push:
self.log.debug("Adding xcom sidecar to task %s", self.task_id)
- pod = xcom_sidecar.add_xcom_sidecar(pod)
+ pod = xcom_sidecar.add_xcom_sidecar(
+ pod, sidecar_container_image=self.hook.get_xcom_sidecar_container_image()
+ )
labels = self._get_ti_pod_labels(context)
- self.log.info("Creating pod %s with labels: %s", pod.metadata.name, labels)
+ self.log.info("Building pod %s with labels: %s", pod.metadata.name, labels)
# Merge Pod Identifying labels with labels passed to operator
pod.metadata.labels.update(labels)
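The namespace fallback added in this hunk applies only when the reconciled pod has no namespace of its own. A minimal sketch of the precedence order, assuming the three candidate values are passed in directly (the function name ``resolve_pod_namespace`` is hypothetical):

```python
from __future__ import annotations


def resolve_pod_namespace(
    operator_namespace: str | None,
    hook_namespace: str | None,
    incluster_namespace: str | None,
) -> str:
    """Pick the first non-empty candidate, falling back to "default"."""
    return operator_namespace or hook_namespace or incluster_namespace or "default"


if __name__ == "__main__":
    print(resolve_pod_namespace(None, None, None))            # default
    print(resolve_pod_namespace(None, "team-a", "airflow"))   # team-a
    print(resolve_pod_namespace("jobs", "team-a", "airflow")) # jobs
```

Because the chain uses ``or``, an empty string is treated the same as ``None`` and falls through to the next candidate.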
@@ -560,7 +626,8 @@ def build_pod_request_obj(self, context=None):
# And a label to identify that pod is launched by KubernetesPodOperator
pod.metadata.labels.update(
{
- 'airflow_version': airflow_version.replace('+', '-'),
+ "airflow_version": airflow_version.replace("+", "-"),
+ "airflow_kpo_in_cluster": str(self.hook.is_in_cluster),
}
)
pod_mutation_hook(pod)
@@ -573,40 +640,7 @@ def dry_run(self) -> None:
one in a dry_run) and excludes all empty elements.
"""
pod = self.build_pod_request_obj()
- print(yaml.dump(prune_dict(pod.to_dict(), mode='strict')))
-
- def _patch_deprecated_k8s_settings(self, hook: KubernetesHook):
- """
- Here we read config from core Airflow config [kubernetes] section.
- In a future release we will stop looking at this section and require users
- to use Airflow connections to configure KPO.
-
- When we find values there that we need to apply on the hook, we patch special
- hook attributes here.
- """
-
- # default for enable_tcp_keepalive is True; patch if False
- if conf.getboolean('kubernetes', 'enable_tcp_keepalive') is False:
- hook._deprecated_core_disable_tcp_keepalive = True
-
- # default verify_ssl is True; patch if False.
- if conf.getboolean('kubernetes', 'verify_ssl') is False:
- hook._deprecated_core_disable_verify_ssl = True
-
- # default for in_cluster is True; patch if False and no KPO param.
- conf_in_cluster = conf.getboolean('kubernetes', 'in_cluster')
- if self.in_cluster is None and conf_in_cluster is False:
- hook._deprecated_core_in_cluster = conf_in_cluster
-
- # there's no default for cluster context; if we get something (and no KPO param) patch it.
- conf_cluster_context = conf.get('kubernetes', 'cluster_context', fallback=None)
- if not self.cluster_context and conf_cluster_context:
- hook._deprecated_core_cluster_context = conf_cluster_context
-
- # there's no default for config_file; if we get something (and no KPO param) patch it.
- conf_config_file = conf.get('kubernetes', 'config_file', fallback=None)
- if not self.config_file and conf_config_file:
- hook._deprecated_core_config_file = conf_config_file
+ print(yaml.dump(prune_dict(pod.to_dict(), mode="strict")))
class _suppress(AbstractContextManager):
@@ -630,5 +664,5 @@ def __exit__(self, exctype, excinst, exctb):
if caught_error:
self.exception = excinst
logger = logging.getLogger(__name__)
- logger.error(str(excinst), exc_info=True)
+ logger.exception(excinst)
return caught_error
diff --git a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index fbf0aebd4948c..ff3828ffb0d4a 100644
--- a/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
@@ -41,18 +43,18 @@ class SparkKubernetesOperator(BaseOperator):
:param api_version: kubernetes api version of sparkApplication
"""
- template_fields: Sequence[str] = ('application_file', 'namespace')
- template_ext: Sequence[str] = ('.yaml', '.yml', '.json')
- ui_color = '#f4a460'
+ template_fields: Sequence[str] = ("application_file", "namespace")
+ template_ext: Sequence[str] = (".yaml", ".yml", ".json")
+ ui_color = "#f4a460"
def __init__(
self,
*,
application_file: str,
- namespace: Optional[str] = None,
- kubernetes_conn_id: str = 'kubernetes_default',
- api_group: str = 'sparkoperator.k8s.io',
- api_version: str = 'v1beta2',
+ namespace: str | None = None,
+ kubernetes_conn_id: str = "kubernetes_default",
+ api_group: str = "sparkoperator.k8s.io",
+ api_version: str = "v1beta2",
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -63,7 +65,7 @@ def __init__(
self.api_version = api_version
self.plural = "sparkapplications"
- def execute(self, context: 'Context'):
+ def execute(self, context: Context):
hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
self.log.info("Creating sparkApplication")
response = hook.create_custom_object(
diff --git a/airflow/providers/cncf/kubernetes/provider.yaml b/airflow/providers/cncf/kubernetes/provider.yaml
index ce2ce89eddeec..3de2fb8717c80 100644
--- a/airflow/providers/cncf/kubernetes/provider.yaml
+++ b/airflow/providers/cncf/kubernetes/provider.yaml
@@ -22,6 +22,11 @@ description: |
`Kubernetes `__
versions:
+ - 5.0.0
+ - 4.4.0
+ - 4.3.0
+ - 4.2.0
+ - 4.1.0
- 4.0.2
- 4.0.1
- 4.0.0
@@ -43,8 +48,18 @@ versions:
- 1.0.1
- 1.0.0
-additional-dependencies:
+dependencies:
- apache-airflow>=2.3.0
+ - cryptography>=2.0.0
+ # The Kubernetes API is known to introduce problems when upgraded to a MAJOR version. Airflow Core
+ # uses Kubernetes for the Kubernetes executor, and the Kubernetes Python client follows SemVer
+ # (https://github.com/kubernetes-client/python#compatibility). This is a crucial component of Airflow,
+ # so we should limit it to the next MAJOR version and only deliberately bump the version once we
+ # have tested it and know it can be bumped. Bumping this version should also be accompanied by
+ # raising the minimum Airflow version supported in the cncf.kubernetes provider, due to the
+ # potential breaking changes in Airflow Core as well (kubernetes is added as an extra, so Airflow
+ # core is not hard-limited via install-requirements, only by the extra).
+ - kubernetes>=21.7.0,<24
integrations:
- integration-name: Kubernetes
@@ -74,9 +89,11 @@ hooks:
python-modules:
- airflow.providers.cncf.kubernetes.hooks.kubernetes
-hook-class-names: # deprecated - to be removed after providers add dependency on Airflow 2.2.0+
- - airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook
connection-types:
- hook-class-name: airflow.providers.cncf.kubernetes.hooks.kubernetes.KubernetesHook
connection-type: kubernetes
+
+task-decorators:
+ - class-name: airflow.providers.cncf.kubernetes.decorators.kubernetes.kubernetes_task
+ name: kubernetes
diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2 b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
new file mode 100644
index 0000000000000..c961f10de4e5c
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.jinja2
@@ -0,0 +1,44 @@
+{#
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-#}
+
+import {{ pickling_library }}
+import sys
+
+{# Check whether Airflow is available in the environment.
+ # If it is, we'll want to ensure that we integrate any macros that are being provided
+ # by plugins prior to unpickling the task context. #}
+if sys.version_info >= (3,6):
+ try:
+ from airflow.plugins_manager import integrate_macros_plugins
+ integrate_macros_plugins()
+ except ImportError:
+ {# Airflow is not available in this environment, therefore we won't
+ # be able to integrate any plugin macros. #}
+ pass
+
+{% if op_args or op_kwargs %}
+with open(sys.argv[1], "rb") as file:
+ arg_dict = {{ pickling_library }}.load(file)
+{% else %}
+arg_dict = {"args": [], "kwargs": {}}
+{% endif %}
+
+# Script
+{{ python_callable_source }}
+res = {{ python_callable }}(*arg_dict["args"], **arg_dict["kwargs"])
diff --git a/airflow/providers/cncf/kubernetes/python_kubernetes_script.py b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
new file mode 100644
index 0000000000000..785daf6e56fb2
--- /dev/null
+++ b/airflow/providers/cncf/kubernetes/python_kubernetes_script.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for using the kubernetes decorator"""
+from __future__ import annotations
+
+import os
+from collections import deque
+
+import jinja2
+
+
+def _balance_parens(after_decorator):
+ num_paren = 1
+ after_decorator = deque(after_decorator)
+ after_decorator.popleft()
+ while num_paren:
+ current = after_decorator.popleft()
+ if current == "(":
+ num_paren = num_paren + 1
+ elif current == ")":
+ num_paren = num_paren - 1
+ return "".join(after_decorator)
+
+
+def remove_task_decorator(python_source: str, task_decorator_name: str) -> str:
+ """
+ Remove the task decorator (for example ``@kubernetes_task``) and its arguments from the source.
+
+ :param python_source: Python source code of the decorated callable.
+ """
+ if task_decorator_name not in python_source:
+ return python_source
+ split = python_source.split(task_decorator_name)
+ before_decorator, after_decorator = split[0], split[1]
+ if after_decorator[0] == "(":
+ after_decorator = _balance_parens(after_decorator)
+ if after_decorator[0] == "\n":
+ after_decorator = after_decorator[1:]
+ return before_decorator + after_decorator
+
+
+def write_python_script(
+ jinja_context: dict,
+ filename: str,
+ render_template_as_native_obj: bool = False,
+):
+ """
+ Renders the python script to a file to execute in the virtual environment.
+
+ :param jinja_context: The jinja context variables to unpack and substitute for the placeholders in the
+ template file.
+ :param filename: The name of the file to dump the rendered script to.
+ :param render_template_as_native_obj: If ``True``, rendered Jinja template would be converted
+ to a native Python object
+ """
+ template_loader = jinja2.FileSystemLoader(searchpath=os.path.dirname(__file__))
+ template_env: jinja2.Environment
+ if render_template_as_native_obj:
+ template_env = jinja2.nativetypes.NativeEnvironment(
+ loader=template_loader, undefined=jinja2.StrictUndefined
+ )
+ else:
+ template_env = jinja2.Environment(loader=template_loader, undefined=jinja2.StrictUndefined)
+ template = template_env.get_template("python_kubernetes_script.jinja2")
+ template.stream(**jinja_context).dump(filename)
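The decorator stripping performed by ``remove_task_decorator`` in the file above can be exercised without Airflow. The sketch below is a standalone copy that follows the diff; the names ``strip_task_decorator`` and ``_skip_balanced_parens``, the decorator name ``@task.kubernetes``, and the sample source are assumptions for illustration.

```python
from collections import deque


def _skip_balanced_parens(text: str) -> str:
    """Drop the leading "(...)" group, honouring nested parentheses."""
    depth = 1
    chars = deque(text)
    chars.popleft()  # consume the opening "("
    while depth:
        current = chars.popleft()
        if current == "(":
            depth += 1
        elif current == ")":
            depth -= 1
    return "".join(chars)


def strip_task_decorator(python_source: str, task_decorator_name: str) -> str:
    """Remove the decorator line (and its arguments) from the source text."""
    if task_decorator_name not in python_source:
        return python_source
    parts = python_source.split(task_decorator_name)
    before, after = parts[0], parts[1]
    if after[0] == "(":
        after = _skip_balanced_parens(after)
    if after[0] == "\n":
        after = after[1:]
    return before + after


if __name__ == "__main__":
    source = "@task.kubernetes(image='python:3.9')\ndef f():\n    return 1\n"
    print(strip_task_decorator(source, "@task.kubernetes"))
```

As in the diff, only the text around the first occurrence of the decorator name is kept, so this sketch assumes the name appears once in the source.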
diff --git a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py
index 15ac40bcdb90a..6c09202ab8782 100644
--- a/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py
+++ b/airflow/providers/cncf/kubernetes/sensors/spark_kubernetes.py
@@ -15,7 +15,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Optional, Sequence
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Sequence
from kubernetes import client
@@ -37,6 +39,7 @@ class SparkKubernetesSensor(BaseSensorOperator):
:param application_name: spark Application resource name
:param namespace: the kubernetes namespace where the sparkApplication reside in
+ :param container_name: the kubernetes container name where the sparkApplication resides
:param kubernetes_conn_id: The :ref:`kubernetes connection`
to Kubernetes cluster.
:param attach_log: determines whether logs for driver pod should be appended to the sensor log
@@ -53,16 +56,18 @@ def __init__(
*,
application_name: str,
attach_log: bool = False,
- namespace: Optional[str] = None,
+ namespace: str | None = None,
+ container_name: str = "spark-kubernetes-driver",
kubernetes_conn_id: str = "kubernetes_default",
- api_group: str = 'sparkoperator.k8s.io',
- api_version: str = 'v1beta2',
+ api_group: str = "sparkoperator.k8s.io",
+ api_version: str = "v1beta2",
**kwargs,
) -> None:
super().__init__(**kwargs)
self.application_name = application_name
self.attach_log = attach_log
self.namespace = namespace
+ self.container_name = container_name
self.kubernetes_conn_id = kubernetes_conn_id
self.hook = KubernetesHook(conn_id=self.kubernetes_conn_id)
self.api_group = api_group
@@ -82,7 +87,9 @@ def _log_driver(self, application_state: str, response: dict) -> None:
log_method = self.log.error if application_state in self.FAILURE_STATES else self.log.info
try:
log = ""
- for line in self.hook.get_pod_logs(driver_pod_name, namespace=namespace):
+ for line in self.hook.get_pod_logs(
+ driver_pod_name, namespace=namespace, container=self.container_name
+ ):
log += line.decode()
log_method(log)
except client.rest.ApiException as e:
@@ -94,7 +101,7 @@ def _log_driver(self, application_state: str, response: dict) -> None:
e,
)
- def poke(self, context: 'Context') -> bool:
+ def poke(self, context: Context) -> bool:
self.log.info("Poking: %s", self.application_name)
response = self.hook.get_custom_object(
group=self.api_group,
diff --git a/airflow/providers/cncf/kubernetes/utils/__init__.py b/airflow/providers/cncf/kubernetes/utils/__init__.py
index 9cebc80e1d022..84e243c6dbc76 100644
--- a/airflow/providers/cncf/kubernetes/utils/__init__.py
+++ b/airflow/providers/cncf/kubernetes/utils/__init__.py
@@ -14,4 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-__all__ = ['xcom_sidecar', 'pod_manager']
+__all__ = ["xcom_sidecar", "pod_manager"]
diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
index 27c9439dbde1d..ffd3e7ec4bd88 100644
--- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py
+++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Launches PODs"""
+from __future__ import annotations
+
import json
import math
import time
@@ -22,7 +24,7 @@
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Iterable, cast
import pendulum
import tenacity
@@ -60,10 +62,10 @@ class PodPhase:
See https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-phase.
"""
- PENDING = 'Pending'
- RUNNING = 'Running'
- FAILED = 'Failed'
- SUCCEEDED = 'Succeeded'
+ PENDING = "Pending"
+ RUNNING = "Running"
+ FAILED = "Failed"
+ SUCCEEDED = "Succeeded"
terminal_states = {FAILED, SUCCEEDED}
@@ -76,7 +78,7 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool:
container_statuses = pod.status.container_statuses if pod and pod.status else None
if not container_statuses:
return False
- container_status = next(iter([x for x in container_statuses if x.name == container_name]), None)
+ container_status = next((x for x in container_statuses if x.name == container_name), None)
if not container_status:
return False
return container_status.state.running is not None
@@ -85,9 +87,9 @@ def container_is_running(pod: V1Pod, container_name: str) -> bool:
def get_container_termination_message(pod: V1Pod, container_name: str):
try:
container_statuses = pod.status.container_statuses
- container_status = next(iter([x for x in container_statuses if x.name == container_name]), None)
+ container_status = next((x for x in container_statuses if x.name == container_name), None)
return container_status.state.terminated.message if container_status else None
- except AttributeError:
+ except (AttributeError, TypeError):
return None
@@ -96,7 +98,7 @@ class PodLoggingStatus:
"""Used for returning the status of the pod and last log time when exiting from `fetch_container_logs`"""
running: bool
- last_log_time: Optional[DateTime]
+ last_log_time: DateTime | None
class PodManager(LoggingMixin):
@@ -109,7 +111,7 @@ def __init__(
self,
kube_client: client.CoreV1Api = None,
in_cluster: bool = True,
- cluster_context: Optional[str] = None,
+ cluster_context: str | None = None,
):
"""
Creates the launcher.
@@ -119,7 +121,16 @@ def __init__(
:param cluster_context: context of the cluster
"""
super().__init__()
- self._client = kube_client or get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context)
+ if kube_client:
+ self._client = kube_client
+ else:
+ self._client = get_kube_client(in_cluster=in_cluster, cluster_context=cluster_context)
+ warnings.warn(
+ "`kube_client` not supplied to PodManager. "
+ "This will be a required argument in a future release. "
+ "Please use KubernetesHook to create the client before calling.",
+ DeprecationWarning,
+ )
self._watch = watch.Watch()
def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
@@ -127,15 +138,15 @@ def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod:
sanitized_pod = self._client.api_client.sanitize_for_serialization(pod)
json_pod = json.dumps(sanitized_pod, indent=2)
- self.log.debug('Pod Creation Request: \n%s', json_pod)
+ self.log.debug("Pod Creation Request: \n%s", json_pod)
try:
resp = self._client.create_namespaced_pod(
body=sanitized_pod, namespace=pod.metadata.namespace, **kwargs
)
- self.log.debug('Pod Creation Response: %s', resp)
+ self.log.debug("Pod Creation Response: %s", resp)
except Exception as e:
self.log.exception(
- 'Exception when attempting to create Namespaced Pod: %s', str(json_pod).replace("\n", " ")
+ "Exception when attempting to create Namespaced Pod: %s", str(json_pod).replace("\n", " ")
)
raise e
return resp
@@ -194,14 +205,14 @@ def follow_container_logs(self, pod: V1Pod, container_name: str) -> PodLoggingSt
return self.fetch_container_logs(pod=pod, container_name=container_name, follow=True)
def fetch_container_logs(
- self, pod: V1Pod, container_name: str, *, follow=False, since_time: Optional[DateTime] = None
+ self, pod: V1Pod, container_name: str, *, follow=False, since_time: DateTime | None = None
) -> PodLoggingStatus:
"""
Follows the logs of container and streams to airflow logging.
Returns when container exits.
"""
- def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True) -> Optional[DateTime]:
+ def consume_logs(*, since_time: DateTime | None = None, follow: bool = True) -> DateTime | None:
"""
Tries to follow container logs until container completes.
For a long-running container, sometimes the log read may be interrupted
@@ -221,7 +232,7 @@ def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True)
follow=follow,
)
for raw_line in logs:
- line = raw_line.decode('utf-8', errors="backslashreplace")
+ line = raw_line.decode("utf-8", errors="backslashreplace")
timestamp, message = self.parse_log_line(line)
self.log.info(message)
except BaseHTTPError as e:
@@ -249,7 +260,7 @@ def consume_logs(*, since_time: Optional[DateTime] = None, follow: bool = True)
return PodLoggingStatus(running=True, last_log_time=last_log_time)
else:
self.log.warning(
- 'Pod %s log read interrupted but container %s still running',
+ "Pod %s log read interrupted but container %s still running",
pod.metadata.name,
container_name,
)
@@ -264,25 +275,24 @@ def await_pod_completion(self, pod: V1Pod) -> V1Pod:
Monitors a pod and returns the final state
:param pod: pod spec that will be monitored
- :return: Tuple[State, Optional[str]]
+ :return: the remote pod once it has reached a terminal phase
"""
while True:
remote_pod = self.read_pod(pod)
if remote_pod.status.phase in PodPhase.terminal_states:
break
- self.log.info('Pod %s has phase %s', pod.metadata.name, remote_pod.status.phase)
+ self.log.info("Pod %s has phase %s", pod.metadata.name, remote_pod.status.phase)
time.sleep(2)
return remote_pod
- def parse_log_line(self, line: str) -> Tuple[Optional[DateTime], str]:
+ def parse_log_line(self, line: str) -> tuple[DateTime | None, str]:
"""
Parse K8s log line and returns the final state
:param line: k8s log line
:return: timestamp and log message
- :rtype: Tuple[str, str]
"""
- split_at = line.find(' ')
+ split_at = line.find(" ")
if split_at == -1:
self.log.error(
"Error parsing timestamp (no timestamp in message %r). "
@@ -309,18 +319,18 @@ def read_pod_logs(
self,
pod: V1Pod,
container_name: str,
- tail_lines: Optional[int] = None,
+ tail_lines: int | None = None,
timestamps: bool = False,
- since_seconds: Optional[int] = None,
+ since_seconds: int | None = None,
follow=True,
) -> Iterable[bytes]:
"""Reads log from the POD"""
additional_kwargs = {}
if since_seconds:
- additional_kwargs['since_seconds'] = since_seconds
+ additional_kwargs["since_seconds"] = since_seconds
if tail_lines:
- additional_kwargs['tail_lines'] = tail_lines
+ additional_kwargs["tail_lines"] = tail_lines
try:
return self._client.read_namespaced_pod_log(
@@ -333,18 +343,18 @@ def read_pod_logs(
**additional_kwargs,
)
except BaseHTTPError:
- self.log.exception('There was an error reading the kubernetes API.')
+ self.log.exception("There was an error reading the kubernetes API.")
raise
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
- def read_pod_events(self, pod: V1Pod) -> "CoreV1EventList":
+ def read_pod_events(self, pod: V1Pod) -> CoreV1EventList:
"""Reads events from the POD"""
try:
return self._client.list_namespaced_event(
namespace=pod.metadata.namespace, field_selector=f"involvedObject.name={pod.metadata.name}"
)
except BaseHTTPError as e:
- raise AirflowException(f'There was an error reading the kubernetes API: {e}')
+ raise AirflowException(f"There was an error reading the kubernetes API: {e}")
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
def read_pod(self, pod: V1Pod) -> V1Pod:
@@ -352,7 +362,19 @@ def read_pod(self, pod: V1Pod) -> V1Pod:
try:
return self._client.read_namespaced_pod(pod.metadata.name, pod.metadata.namespace)
except BaseHTTPError as e:
- raise AirflowException(f'There was an error reading the kubernetes API: {e}')
+ raise AirflowException(f"There was an error reading the kubernetes API: {e}")
+
+ def await_xcom_sidecar_container_start(self, pod: V1Pod) -> None:
+ self.log.info("Checking if xcom sidecar container is started.")
+ warned = False
+ while True:
+ if self.container_is_running(pod, PodDefaults.SIDECAR_CONTAINER_NAME):
+ self.log.info("The xcom sidecar container is started.")
+ break
+ if not warned:
+ self.log.warning("The xcom sidecar container is not yet started.")
+ warned = True
+ time.sleep(1)
def extract_xcom(self, pod: V1Pod) -> str:
"""Retrieves XCom value and kills xcom sidecar container"""
@@ -362,7 +384,7 @@ def extract_xcom(self, pod: V1Pod) -> str:
pod.metadata.name,
pod.metadata.namespace,
container=PodDefaults.SIDECAR_CONTAINER_NAME,
- command=['/bin/sh'],
+ command=["/bin/sh"],
stdin=True,
stdout=True,
stderr=True,
@@ -370,17 +392,20 @@ def extract_xcom(self, pod: V1Pod) -> str:
_preload_content=False,
)
) as resp:
- result = self._exec_pod_command(resp, f'cat {PodDefaults.XCOM_MOUNT_PATH}/return.json')
- self._exec_pod_command(resp, 'kill -s SIGINT 1')
+ result = self._exec_pod_command(
+ resp,
+ f"if [ -s {PodDefaults.XCOM_MOUNT_PATH}/return.json ]; then cat {PodDefaults.XCOM_MOUNT_PATH}/return.json; else echo __airflow_xcom_result_empty__; fi", # noqa
+ )
+ self._exec_pod_command(resp, "kill -s SIGINT 1")
if result is None:
- raise AirflowException(f'Failed to extract xcom from pod: {pod.metadata.name}')
+ raise AirflowException(f"Failed to extract xcom from pod: {pod.metadata.name}")
return result
- def _exec_pod_command(self, resp, command: str) -> Optional[str]:
+ def _exec_pod_command(self, resp, command: str) -> str | None:
res = None
if resp.is_open():
- self.log.info('Running command... %s\n', command)
- resp.write_stdin(command + '\n')
+ self.log.info("Running command... %s\n", command)
+ resp.write_stdin(command + "\n")
while resp.is_open():
resp.update(timeout=1)
while resp.peek_stdout():
diff --git a/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py b/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py
index a8c0ea4c1936f..81b3047993691 100644
--- a/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py
+++ b/airflow/providers/cncf/kubernetes/utils/xcom_sidecar.py
@@ -19,6 +19,8 @@
by attaching a sidecar container that blocks the pod from completing until
Airflow has pulled result data into the worker for xcom serialization.
"""
+from __future__ import annotations
+
import copy
from kubernetes.client import models as k8s
@@ -27,15 +29,15 @@
class PodDefaults:
"""Static defaults for Pods"""
- XCOM_MOUNT_PATH = '/airflow/xcom'
- SIDECAR_CONTAINER_NAME = 'airflow-xcom-sidecar'
+ XCOM_MOUNT_PATH = "/airflow/xcom"
+ SIDECAR_CONTAINER_NAME = "airflow-xcom-sidecar"
XCOM_CMD = 'trap "exit 0" INT; while true; do sleep 1; done;'
- VOLUME_MOUNT = k8s.V1VolumeMount(name='xcom', mount_path=XCOM_MOUNT_PATH)
- VOLUME = k8s.V1Volume(name='xcom', empty_dir=k8s.V1EmptyDirVolumeSource())
+ VOLUME_MOUNT = k8s.V1VolumeMount(name="xcom", mount_path=XCOM_MOUNT_PATH)
+ VOLUME = k8s.V1Volume(name="xcom", empty_dir=k8s.V1EmptyDirVolumeSource())
SIDECAR_CONTAINER = k8s.V1Container(
name=SIDECAR_CONTAINER_NAME,
- command=['sh', '-c', XCOM_CMD],
- image='alpine',
+ command=["sh", "-c", XCOM_CMD],
+ image="alpine",
volume_mounts=[VOLUME_MOUNT],
resources=k8s.V1ResourceRequirements(
requests={
@@ -45,13 +47,15 @@ class PodDefaults:
)
-def add_xcom_sidecar(pod: k8s.V1Pod) -> k8s.V1Pod:
+def add_xcom_sidecar(pod: k8s.V1Pod, *, sidecar_container_image=None) -> k8s.V1Pod:
"""Adds sidecar"""
pod_cp = copy.deepcopy(pod)
pod_cp.spec.volumes = pod.spec.volumes or []
pod_cp.spec.volumes.insert(0, PodDefaults.VOLUME)
pod_cp.spec.containers[0].volume_mounts = pod_cp.spec.containers[0].volume_mounts or []
pod_cp.spec.containers[0].volume_mounts.insert(0, PodDefaults.VOLUME_MOUNT)
- pod_cp.spec.containers.append(PodDefaults.SIDECAR_CONTAINER)
+ sidecar = copy.deepcopy(PodDefaults.SIDECAR_CONTAINER)
+ sidecar.image = sidecar_container_image or PodDefaults.SIDECAR_CONTAINER.image
+ pod_cp.spec.containers.append(sidecar)
return pod_cp
diff --git a/airflow/providers/github/example_dags/__init__.py b/airflow/providers/common/__init__.py
similarity index 100%
rename from airflow/providers/github/example_dags/__init__.py
rename to airflow/providers/common/__init__.py
diff --git a/airflow/providers/common/sql/.latest-doc-only-change.txt b/airflow/providers/common/sql/.latest-doc-only-change.txt
new file mode 100644
index 0000000000000..ff7136e07d744
--- /dev/null
+++ b/airflow/providers/common/sql/.latest-doc-only-change.txt
@@ -0,0 +1 @@
+06acf40a4337759797f666d5bb27a5a393b74fed
diff --git a/airflow/providers/common/sql/CHANGELOG.rst b/airflow/providers/common/sql/CHANGELOG.rst
new file mode 100644
index 0000000000000..16c1cd098338e
--- /dev/null
+++ b/airflow/providers/common/sql/CHANGELOG.rst
@@ -0,0 +1,127 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
+
+Changelog
+---------
+
+1.3.1
+.....
+
+This release fixes a few errors that were introduced in common.sql operator while refactoring common parts:
+
+* ``_process_output`` method in ``SQLExecuteQueryOperator`` now has consistent semantics and typing; it
+ can also modify the returned (and stored in XCom) values in the operators that derive from
+ ``SQLExecuteQueryOperator``.
+* The last description of the cursor and whether to return scalar values are now stored in ``DbApiHook``.
+
+Lack of consistency in the operator caused ``1.3.0`` to be yanked. Version ``1.3.0`` should not be used;
+if you have ``1.3.0`` installed, upgrade to ``1.3.1``.
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Restore removed (but used) methods in common.sql (#27843)``
+* ``Fix errors in Databricks SQL operator introduced when refactoring (#27854)``
+
+
+1.3.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy `_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+
+Features
+~~~~~~~~
+
+* ``Add SQLExecuteQueryOperator (#25717)``
+* ``Use DbApiHook.run for DbApiHook.get_records and DbApiHook.get_first (#26944)``
+* ``DbApiHook consistent insert_rows logging (#26758)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Common sql bugfixes and improvements (#26761)``
+* ``Use unused SQLCheckOperator.parameters in SQLCheckOperator.execute. (#27599)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+ * ``Update docs for September Provider's release (#26731)``
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+
+1.2.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Make placeholder style configurable (#25939)``
+* ``Better error message for pre-common-sql providers (#26051)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Fix (and test) SQLTableCheckOperator on postgresql (#25821)``
+* ``Don't use Pandas for SQLTableCheckOperator (#25822)``
+* ``Discard semicolon stripping in SQL hook (#25855)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+
+
+1.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Improve taskflow type hints with ParamSpec (#25173)``
+* ``Move all "old" SQL operators to common.sql providers (#25350)``
+* ``Deprecate hql parameters and synchronize DBApiHook method APIs (#25299)``
+* ``Unify DbApiHook.run() method with the methods which override it (#23971)``
+* ``Common SQLCheckOperators Various Functionality Update (#25164)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Allow Legacy SqlSensor to use the common.sql providers (#25293)``
+* ``Fix fetch_all_handler & db-api tests for it (#25430)``
+* ``Align Common SQL provider logo location (#25538)``
+* ``Fix SQL split string to include ';-less' statements (#25713)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Fix CHANGELOG for common.sql provider and add amazon commit (#25636)``
+
+1.0.0
+.....
+
+Initial version of the provider.
+Adds ``SQLColumnCheckOperator`` and ``SQLTableCheckOperator``.
+Moves ``DBApiHook``, ``SQLSensor`` and ``ConnectorProtocol`` to the provider.
diff --git a/airflow/providers/google/ads/example_dags/__init__.py b/airflow/providers/common/sql/__init__.py
similarity index 100%
rename from airflow/providers/google/ads/example_dags/__init__.py
rename to airflow/providers/common/sql/__init__.py
diff --git a/airflow/providers/google/firebase/example_dags/__init__.py b/airflow/providers/common/sql/hooks/__init__.py
similarity index 100%
rename from airflow/providers/google/firebase/example_dags/__init__.py
rename to airflow/providers/common/sql/hooks/__init__.py
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
new file mode 100644
index 0000000000000..df808430fd9aa
--- /dev/null
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -0,0 +1,433 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from contextlib import closing
+from datetime import datetime
+from typing import Any, Callable, Iterable, Mapping, Sequence, cast
+
+import sqlparse
+from packaging.version import Version
+from sqlalchemy import create_engine
+from typing_extensions import Protocol
+
+from airflow import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.version import version
+
+
+def fetch_all_handler(cursor) -> list[tuple] | None:
+ """Handler for DbApiHook.run() to return results"""
+ if cursor.description is not None:
+ return cursor.fetchall()
+ else:
+ return None
+
+
+def fetch_one_handler(cursor) -> tuple | None:
+ """Handler for DbApiHook.run() to return the first result row"""
+ if cursor.description is not None:
+ return cursor.fetchone()
+ else:
+ return None
+
+
+class ConnectorProtocol(Protocol):
+ """A protocol where you can connect to a database."""
+
+ def connect(self, host: str, port: int, username: str, schema: str) -> Any:
+ """
+ Connect to a database.
+
+ :param host: The database host to connect to.
+ :param port: The database port to connect to.
+ :param username: The database username used for the authentication.
+ :param schema: The database schema to connect to.
+ :return: the authorized connection object.
+ """
+
+
+# When running on Airflow 2.4+, DbApiHook should derive from BaseHook. On Airflow 2.3 and below,
+# we want the DbApiHook to derive from the original DbApiHook from airflow, because otherwise
+# SqlSensor and BaseSqlOperator from "airflow.operators" and "airflow.sensors" will refuse to
+# accept the new Hooks as not derived from the original DbApiHook
+if Version(version) < Version("2.4"):
+ try:
+ from airflow.hooks.dbapi import DbApiHook as BaseForDbApiHook
+ except ImportError:
+ # just in case we have a problem with circular import
+ BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef]
+else:
+ BaseForDbApiHook: type[BaseHook] = BaseHook # type: ignore[no-redef]
+
+
+class DbApiHook(BaseForDbApiHook):
+ """
+ Abstract base class for sql hooks.
+
+ :param schema: Optional DB schema that overrides the schema specified in the connection. Make sure that
+ if you change the schema parameter value in the constructor of the derived Hook, such change
+ should be done before calling the ``DBApiHook.__init__()``.
+ :param log_sql: Whether to log SQL query when it's executed. Defaults to *True*.
+ """
+
+ # Override to provide the connection name.
+ conn_name_attr: str
+ # Override to have a default connection id for a particular dbHook
+ default_conn_name = "default_conn_id"
+ # Override if this db supports autocommit.
+ supports_autocommit = False
+ # Override with the object that exposes the connect method
+ connector: ConnectorProtocol | None = None
+ # Override with db-specific query to check connection
+ _test_connection_sql = "select 1"
+ # Override with the db-specific value used for placeholders
+ placeholder: str = "%s"
+
+ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs):
+ super().__init__()
+ if not self.conn_name_attr:
+ raise AirflowException("conn_name_attr is not defined")
+ elif len(args) == 1:
+ setattr(self, self.conn_name_attr, args[0])
+ elif self.conn_name_attr not in kwargs:
+ setattr(self, self.conn_name_attr, self.default_conn_name)
+ else:
+ setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr])
+ # For backwards compatibility, we should not make schema available in deriving hooks.
+ # If a hook deriving from DBApiHook needs to access schema, then it should retrieve it
+ # from kwargs and store it on its own. We do not run "pop" here because we want the
+ # Hook deriving from the DBApiHook to still have access to the field in its constructor
+ self.__schema = schema
+ self.log_sql = log_sql
+ self.scalar_return_last = False
+ self.last_description: Sequence[Sequence] | None = None
+
+ def get_conn(self):
+ """Returns a connection object"""
+ db = self.get_connection(getattr(self, cast(str, self.conn_name_attr)))
+ return self.connector.connect(host=db.host, port=db.port, username=db.login, schema=db.schema)
+
+ def get_uri(self) -> str:
+ """
+ Extract the URI from the connection.
+
+ :return: the extracted uri.
+ """
+ conn = self.get_connection(getattr(self, self.conn_name_attr))
+ conn.schema = self.__schema or conn.schema
+ return conn.get_uri()
+
+ def get_sqlalchemy_engine(self, engine_kwargs=None):
+ """
+ Get an sqlalchemy_engine object.
+
+ :param engine_kwargs: Kwargs used in :func:`~sqlalchemy.create_engine`.
+ :return: the created engine.
+ """
+ if engine_kwargs is None:
+ engine_kwargs = {}
+ return create_engine(self.get_uri(), **engine_kwargs)
+
+ def get_pandas_df(self, sql, parameters=None, **kwargs):
+ """
+ Executes the sql and returns a pandas dataframe
+
+ :param sql: the sql statement to be executed (str) or a list of
+ sql statements to execute
+ :param parameters: The parameters to render the SQL query with.
+ :param kwargs: (optional) passed into pandas.io.sql.read_sql method
+ """
+ try:
+ from pandas.io import sql as psql
+ except ImportError:
+ raise Exception(
+ "pandas library not installed, run: pip install "
+ "'apache-airflow-providers-common-sql[pandas]'."
+ )
+
+ with closing(self.get_conn()) as conn:
+ return psql.read_sql(sql, con=conn, params=parameters, **kwargs)
+
+ def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs):
+ """
+ Executes the sql and returns a generator
+
+ :param sql: the sql statement to be executed (str) or a list of
+ sql statements to execute
+ :param parameters: The parameters to render the SQL query with
+ :param chunksize: number of rows to include in each chunk
+ :param kwargs: (optional) passed into pandas.io.sql.read_sql method
+ """
+ try:
+ from pandas.io import sql as psql
+ except ImportError:
+ raise Exception(
+ "pandas library not installed, run: pip install "
+ "'apache-airflow-providers-common-sql[pandas]'."
+ )
+
+ with closing(self.get_conn()) as conn:
+ yield from psql.read_sql(sql, con=conn, params=parameters, chunksize=chunksize, **kwargs)
+
+ def get_records(
+ self,
+ sql: str | list[str],
+ parameters: Iterable | Mapping | None = None,
+ ) -> Any:
+ """
+ Executes the sql and returns a set of records.
+
+ :param sql: the sql statement to be executed (str) or a list of sql statements to execute
+ :param parameters: The parameters to render the SQL query with.
+ """
+ return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler)
+
+ def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
+ """
+ Executes the sql and returns the first resulting row.
+
+ :param sql: the sql statement to be executed (str) or a list of sql statements to execute
+ :param parameters: The parameters to render the SQL query with.
+ """
+ return self.run(sql=sql, parameters=parameters, handler=fetch_one_handler)
+
+ @staticmethod
+ def strip_sql_string(sql: str) -> str:
+ return sql.strip().rstrip(";")
+
+ @staticmethod
+ def split_sql_string(sql: str) -> list[str]:
+ """
+ Splits string into multiple SQL expressions
+
+ :param sql: SQL string potentially consisting of multiple expressions
+ :return: list of individual expressions
+ """
+ splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
+ statements: list[str] = list(filter(None, splits))
+ return statements
+
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = False,
+ parameters: Iterable | Mapping | None = None,
+ handler: Callable | None = None,
+ split_statements: bool = False,
+ return_last: bool = True,
+ ) -> Any | list[Any] | None:
+ """
+ Runs a command or a list of commands. Pass a list of sql
+ statements to the sql parameter to get them to execute
+ sequentially
+
+ :param sql: the sql statement to be executed (str) or a list of
+ sql statements to execute
+ :param autocommit: What to set the connection's autocommit setting to
+ before executing the query.
+ :param parameters: The parameters to render the SQL query with.
+ :param handler: The result handler which is called with the result of each statement.
+ :param split_statements: Whether to split a single SQL string into statements and run separately
+ :param return_last: Whether to return result for only last statement or for all after split
+ :return: the last result, or a list of all results, if a handler was provided; otherwise ``None``.
+ """
+ self.scalar_return_last = isinstance(sql, str) and return_last
+ if isinstance(sql, str):
+ if split_statements:
+ sql = self.split_sql_string(sql)
+ else:
+ sql = [sql]
+
+ if sql:
+ self.log.debug("Executing following statements against DB: %s", list(sql))
+ else:
+ raise ValueError("List of SQL statements is empty")
+
+ with closing(self.get_conn()) as conn:
+ if self.supports_autocommit:
+ self.set_autocommit(conn, autocommit)
+
+ with closing(conn.cursor()) as cur:
+ results = []
+ for sql_statement in sql:
+ self._run_command(cur, sql_statement, parameters)
+
+ if handler is not None:
+ result = handler(cur)
+ results.append(result)
+ self.last_description = cur.description
+
+ # If autocommit was set to False or db does not support autocommit, we do a manual commit.
+ if not self.get_autocommit(conn):
+ conn.commit()
+
+ if handler is None:
+ return None
+ elif self.scalar_return_last:
+ return results[-1]
+ else:
+ return results
+
+ def _run_command(self, cur, sql_statement, parameters):
+ """Runs a statement using an already open cursor."""
+ if self.log_sql:
+ self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
+
+ if parameters:
+ cur.execute(sql_statement, parameters)
+ else:
+ cur.execute(sql_statement)
+
+ # According to PEP 249, this is -1 when query result is not applicable.
+ if cur.rowcount >= 0:
+ self.log.info("Rows affected: %s", cur.rowcount)
+
+ def set_autocommit(self, conn, autocommit):
+ """Sets the autocommit flag on the connection"""
+ if not self.supports_autocommit and autocommit:
+ self.log.warning(
+ "%s connection doesn't support autocommit but autocommit activated.",
+ getattr(self, self.conn_name_attr),
+ )
+ conn.autocommit = autocommit
+
+ def get_autocommit(self, conn) -> bool:
+ """
+ Get autocommit setting for the provided connection.
+ Return True if conn.autocommit is set to True.
+ Return False if conn.autocommit is not set or set to False or conn
+ does not support autocommit.
+
+ :param conn: Connection to get autocommit setting from.
+ :return: connection autocommit setting.
+ """
+ return getattr(conn, "autocommit", False) and self.supports_autocommit
+
+ def get_cursor(self):
+ """Returns a cursor"""
+ return self.get_conn().cursor()
+
+ @classmethod
+ def _generate_insert_sql(cls, table, values, target_fields, replace, **kwargs) -> str:
+ """
+ Helper class method that generates the INSERT SQL statement.
+ The REPLACE variant is specific to MySQL syntax.
+
+ :param table: Name of the target table
+ :param values: The row to insert into the table
+ :param target_fields: The names of the columns to fill in the table
+ :param replace: Whether to replace instead of insert
+ :return: The generated INSERT or REPLACE SQL statement
+ """
+ placeholders = [
+ cls.placeholder,
+ ] * len(values)
+
+ if target_fields:
+ target_fields = ", ".join(target_fields)
+ target_fields = f"({target_fields})"
+ else:
+ target_fields = ""
+
+ if not replace:
+ sql = "INSERT INTO "
+ else:
+ sql = "REPLACE INTO "
+ sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})"
+ return sql
+
+ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs):
+ """
+ A generic way to insert a set of tuples into a table,
+ a new transaction is created every commit_every rows
+
+ :param table: Name of the target table
+ :param rows: The rows to insert into the table
+ :param target_fields: The names of the columns to fill in the table
+ :param commit_every: The maximum number of rows to insert in one
+ transaction. Set to 0 to insert all rows in one transaction.
+ :param replace: Whether to replace instead of insert
+ """
+ i = 0
+ with closing(self.get_conn()) as conn:
+ if self.supports_autocommit:
+ self.set_autocommit(conn, False)
+
+ conn.commit()
+
+ with closing(conn.cursor()) as cur:
+ for i, row in enumerate(rows, 1):
+ lst = []
+ for cell in row:
+ lst.append(self._serialize_cell(cell, conn))
+ values = tuple(lst)
+ sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs)
+ self.log.debug("Generated sql: %s", sql)
+ cur.execute(sql, values)
+ if commit_every and i % commit_every == 0:
+ conn.commit()
+ self.log.info("Loaded %s rows into %s so far", i, table)
+
+ conn.commit()
+ self.log.info("Done loading. Loaded a total of %s rows into %s", i, table)
+
+ @staticmethod
+ def _serialize_cell(cell, conn=None) -> str | None:
+ """
+ Returns the SQL literal of the cell as a string.
+
+ :param cell: The cell to insert into the table
+ :param conn: The database connection
+ :return: The serialized cell
+ """
+ if cell is None:
+ return None
+ if isinstance(cell, datetime):
+ return cell.isoformat()
+ return str(cell)
+
+ def bulk_dump(self, table, tmp_file):
+ """
+ Dumps a database table into a tab-delimited file
+
+ :param table: The name of the source table
+ :param tmp_file: The path of the target file
+ """
+ raise NotImplementedError()
+
+ def bulk_load(self, table, tmp_file):
+ """
+ Loads a tab-delimited file into a database table
+
+ :param table: The name of the target table
+ :param tmp_file: The path of the file to load into the table
+ """
+ raise NotImplementedError()
+
+ def test_connection(self):
+ """Tests the connection using db-specific query"""
+ status, message = False, ""
+ try:
+ if self.get_first(self._test_connection_sql):
+ status = True
+ message = "Connection successfully tested"
+ except Exception as e:
+ status = False
+ message = str(e)
+
+ return status, message
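
The return-value rules of ``run`` in the hook above can be sketched in isolation. This is a minimal, hypothetical stand-in rather than the provider's implementation: ``run_statements`` and ``fetch_all`` are illustrative names, the standard-library ``sqlite3`` module replaces a real Airflow connection, and a naive split on ``;`` replaces ``split_sql_string``.

```python
import sqlite3
from contextlib import closing


def run_statements(conn, sql, handler=None, split_statements=False, return_last=True):
    """Hypothetical stand-in that mirrors the return-value rules of ``run``."""
    # A single result is returned only when a string was passed and return_last is set.
    scalar_return_last = isinstance(sql, str) and return_last
    if isinstance(sql, str):
        if split_statements:
            # Naive split for illustration only; the hook uses split_sql_string.
            sql = [part.strip() for part in sql.split(";") if part.strip()]
        else:
            sql = [sql]
    if not sql:
        raise ValueError("List of SQL statements is empty")

    results = []
    with closing(conn.cursor()) as cur:
        for statement in sql:
            cur.execute(statement)
            if handler is not None:
                results.append(handler(cur))
    conn.commit()

    if handler is None:
        return None
    return results[-1] if scalar_return_last else results


def fetch_all(cursor):
    """Illustrative handler: fetch every row from the cursor."""
    return cursor.fetchall()


conn = sqlite3.connect(":memory:")
# String input with return_last=True: only the last statement's result comes back.
single = run_statements(conn, "SELECT 1; SELECT 2", handler=fetch_all, split_statements=True)
# List input: one result per statement, regardless of return_last.
every = run_statements(conn, ["SELECT 1", "SELECT 2"], handler=fetch_all)
# No handler: nothing is returned.
nothing = run_statements(conn, "SELECT 1")
```

Under these assumptions, a caller that passes a list of statements always gets a list of results back, while a caller that passes one string gets a single result unless ``return_last`` is disabled.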
diff --git a/airflow/providers/influxdb/example_dags/__init__.py b/airflow/providers/common/sql/operators/__init__.py
similarity index 100%
rename from airflow/providers/influxdb/example_dags/__init__.py
rename to airflow/providers/common/sql/operators/__init__.py
diff --git a/airflow/providers/common/sql/operators/sql.py b/airflow/providers/common/sql/operators/sql.py
new file mode 100644
index 0000000000000..314af43003488
--- /dev/null
+++ b/airflow/providers/common/sql/operators/sql.py
@@ -0,0 +1,1104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import ast
+import re
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, NoReturn, Sequence, SupportsAbs, overload
+
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException, AirflowFailException
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator, SkipMixin
+from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler
+from airflow.typing_compat import Literal
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+def _convert_to_float_if_possible(s: str) -> float | str:
+ try:
+ return float(s)
+ except (ValueError, TypeError):
+ return s
+
+
+def _parse_boolean(val: str) -> str | bool:
+ """Try to parse a string into boolean.
+
+ Raises ValueError if the input is not a valid true- or false-like string value.
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ if val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ raise ValueError(f"{val!r} is not a boolean-like string value")
+
+
+def _get_failed_checks(checks, col=None):
+ """
+ IMPORTANT!!! Keep it for compatibility with released 8.4.0 version of google provider.
+
+ Unfortunately the provider imported _get_failed_checks and parse_boolean, so we should
+ keep those methods to prevent the 8.4.0 version from failing.
+ """
+ if col:
+ return [
+ f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n"
+ for check, check_values in checks.items()
+ if not check_values["success"]
+ ]
+ return [
+ f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
+ for check, check_values in checks.items()
+ if not check_values["success"]
+ ]
+
+
+parse_boolean = _parse_boolean
+"""
+IMPORTANT!!! Keep it for compatibility with released 8.4.0 version of google provider.
+
+Unfortunately the provider imported _get_failed_checks and parse_boolean, so we should
+keep those methods to prevent the 8.4.0 version from failing.
+"""
+
+
+_PROVIDERS_MATCHER = re.compile(r"airflow\.providers\.(.*)\.hooks.*")
+
+_MIN_SUPPORTED_PROVIDERS_VERSION = {
+ "amazon": "4.1.0",
+ "apache.drill": "2.1.0",
+ "apache.druid": "3.1.0",
+ "apache.hive": "3.1.0",
+ "apache.pinot": "3.1.0",
+ "databricks": "3.1.0",
+ "elasticsearch": "4.1.0",
+ "exasol": "3.1.0",
+ "google": "8.2.0",
+ "jdbc": "3.1.0",
+ "mssql": "3.1.0",
+ "mysql": "3.1.0",
+ "odbc": "3.1.0",
+ "oracle": "3.1.0",
+ "postgres": "5.1.0",
+ "presto": "3.1.0",
+ "qubole": "3.1.0",
+ "slack": "5.1.0",
+ "snowflake": "3.1.0",
+ "sqlite": "3.1.0",
+ "trino": "3.1.0",
+ "vertica": "3.1.0",
+}
+
+
+class BaseSQLOperator(BaseOperator):
+ """
+ This is a base class for generic SQL Operator to get a DB Hook
+
+ The provided method is .get_db_hook(). The default behavior will try to
+ retrieve the DB hook based on connection type.
+ You can customize the behavior by overriding the .get_db_hook() method.
+
+ :param conn_id: reference to a specific database
+ """
+
+ def __init__(
+ self,
+ *,
+ conn_id: str | None = None,
+ database: str | None = None,
+ hook_params: dict | None = None,
+ retry_on_failure: bool = True,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.conn_id = conn_id
+ self.database = database
+ self.hook_params = {} if hook_params is None else hook_params
+ self.retry_on_failure = retry_on_failure
+
+ @cached_property
+ def _hook(self):
+ """Get DB Hook based on connection type"""
+ self.log.debug("Get connection for %s", self.conn_id)
+ conn = BaseHook.get_connection(self.conn_id)
+ hook = conn.get_hook(hook_params=self.hook_params)
+ if not isinstance(hook, DbApiHook):
+ from airflow.hooks.dbapi_hook import DbApiHook as _DbApiHook
+
+ if isinstance(hook, _DbApiHook):
+ # This case might happen if the user installed the common.sql provider but did not upgrade
+ # the other providers to versions that support the common.sql provider
+ class_module = hook.__class__.__module__
+ match = _PROVIDERS_MATCHER.match(class_module)
+ if match:
+ provider = match.group(1)
+ min_version = _MIN_SUPPORTED_PROVIDERS_VERSION.get(provider)
+ if min_version:
+ raise AirflowException(
+ f"You are trying to use common-sql with {hook.__class__.__name__},"
+ f" but the Hook class comes from provider {provider} that does not support it."
+ f" Please upgrade provider {provider} to at least {min_version}."
+ )
+ raise AirflowException(
+ f"You are trying to use `common-sql` with {hook.__class__.__name__},"
+ " but its provider does not support it. Please upgrade the provider to a version that"
+ " supports `common-sql`. The hook class should be a subclass of"
+ " `airflow.providers.common.sql.hooks.sql.DbApiHook`."
+ f" Got {hook.__class__.__name__} Hook with class hierarchy: {hook.__class__.mro()}"
+ )
+
+ if self.database:
+ hook.schema = self.database
+
+ return hook
+
+ def get_db_hook(self) -> DbApiHook:
+ """
+ Get the database hook for the connection.
+
+ :return: the database hook object.
+ """
+ return self._hook
+
+ def _raise_exception(self, exception_string: str) -> NoReturn:
+ if self.retry_on_failure:
+ raise AirflowException(exception_string)
+ raise AirflowFailException(exception_string)
+
+
+class SQLExecuteQueryOperator(BaseSQLOperator):
+ """
+ Executes SQL code in a specific database
+ :param sql: the SQL code or string pointing to a template file to be executed (templated).
+ The file must have a '.sql' extension.
+ :param autocommit: (optional) if True, each command is automatically committed (default: False).
+ :param parameters: (optional) the parameters to render the SQL query with.
+ :param handler: (optional) the function that will be applied to the cursor (default: fetch_all_handler).
+ :param split_statements: (optional) whether to split a single SQL string into statements (default: False).
+ :param return_last: (optional) whether to return the result of only the last statement (default: True).
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:SQLExecuteQueryOperator`
+ """
+
+ template_fields: Sequence[str] = ("sql", "parameters")
+ template_ext: Sequence[str] = (".sql", ".json")
+ template_fields_renderers = {"sql": "sql", "parameters": "json"}
+ ui_color = "#cdaaed"
+
+ def __init__(
+ self,
+ *,
+ sql: str | list[str],
+ autocommit: bool = False,
+ parameters: Mapping | Iterable | None = None,
+ handler: Callable[[Any], Any] = fetch_all_handler,
+ split_statements: bool = False,
+ return_last: bool = True,
+ **kwargs,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.sql = sql
+ self.autocommit = autocommit
+ self.parameters = parameters
+ self.handler = handler
+ self.split_statements = split_statements
+ self.return_last = return_last
+
+ @overload
+ def _process_output(
+ self, results: Any, description: Sequence[Sequence] | None, scalar_results: Literal[True]
+ ) -> Any:
+ pass
+
+ @overload
+ def _process_output(
+ self, results: list[Any], description: Sequence[Sequence] | None, scalar_results: Literal[False]
+ ) -> Any:
+ pass
+
+ def _process_output(
+ self, results: Any | list[Any], description: Sequence[Sequence] | None, scalar_results: bool
+ ) -> Any:
+ """
+ Can be overridden by the subclass in case some extra processing is needed.
+ The ``_process_output`` method can augment or otherwise process the output as needed. The value it
+ returns becomes the return value of ``execute`` and, if do_xcom_push is set to True, is pushed
+ as the XCom return value.
+
+ :param results: results in the form of list of rows.
+ :param description: as returned by ``cur.description`` in the Python DBAPI
+ :param scalar_results: True if result is single scalar value rather than list of rows
+ """
+ return results
+
+ def execute(self, context):
+ self.log.info("Executing: %s", self.sql)
+ hook = self.get_db_hook()
+ if self.do_xcom_push:
+ output = hook.run(
+ sql=self.sql,
+ autocommit=self.autocommit,
+ parameters=self.parameters,
+ handler=self.handler,
+ split_statements=self.split_statements,
+ return_last=self.return_last,
+ )
+ else:
+ output = hook.run(
+ sql=self.sql,
+ autocommit=self.autocommit,
+ parameters=self.parameters,
+ split_statements=self.split_statements,
+ )
+
+ return self._process_output(output, hook.last_description, hook.scalar_return_last)
+
+ def prepare_template(self) -> None:
+ """Parse template file for attribute parameters."""
+ if isinstance(self.parameters, str):
+ self.parameters = ast.literal_eval(self.parameters)
+
+
+class SQLColumnCheckOperator(BaseSQLOperator):
+ """
+ Performs one or more of the templated checks in the column_checks dictionary.
+ Checks are performed on a per-column basis specified by the column_mapping.
+ Each check can take one or more of the following options:
+ - equal_to: an exact value to equal, cannot be used with other comparison options
+ - greater_than: value that result should be strictly greater than
+ - less_than: value that results should be strictly less than
+ - geq_to: value that results should be greater than or equal to
+ - leq_to: value that results should be less than or equal to
+ - tolerance: the percentage that the result may be off from the expected value
+ - partition_clause: an extra clause passed into a WHERE statement to partition data
+
+ :param table: the table to run checks on
+ :param column_mapping: the dictionary of columns and their associated checks, e.g.
+
+ .. code-block:: python
+
+ {
+ "col_name": {
+ "null_check": {
+ "equal_to": 0,
+ "partition_clause": "foreign_key IS NOT NULL",
+ },
+ "min": {
+ "greater_than": 5,
+ "leq_to": 10,
+ "tolerance": 0.2,
+ },
+ "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01},
+ }
+ }
+
+ :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by
+ the operator that creates partition_clauses for the checks to run on, e.g.
+
+ .. code-block:: python
+
+ "date = '1970-01-01'"
+
+ :param conn_id: the connection ID used to connect to the database
+ :param database: name of the database that overrides the one defined in the connection
+ :param accept_none: whether or not to accept None values returned by the query. If true, converts None
+ to 0.
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:SQLColumnCheckOperator`
+ """
+
+ template_fields = ("partition_clause",)
+
+ sql_check_template = """
+ SELECT '{column}' AS col_name, '{check}' AS check_type, {column}_{check} AS check_result
+ FROM (SELECT {check_statement} AS {column}_{check} FROM {table} {partition_clause}) AS sq
+ """
+
+ column_checks = {
+ "null_check": "SUM(CASE WHEN {column} IS NULL THEN 1 ELSE 0 END)",
+ "distinct_check": "COUNT(DISTINCT({column}))",
+ "unique_check": "COUNT({column}) - COUNT(DISTINCT({column}))",
+ "min": "MIN({column})",
+ "max": "MAX({column})",
+ }
+
+ def __init__(
+ self,
+ *,
+ table: str,
+ column_mapping: dict[str, dict[str, Any]],
+ partition_clause: str | None = None,
+ conn_id: str | None = None,
+ database: str | None = None,
+ accept_none: bool = True,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+
+ self.table = table
+ self.column_mapping = column_mapping
+ self.partition_clause = partition_clause
+ self.accept_none = accept_none
+
+ def _build_checks_sql():
+ for column, checks in self.column_mapping.items():
+ for check, check_values in checks.items():
+ self._column_mapping_validation(check, check_values)
+ yield self._generate_sql_query(column, checks)
+
+ checks_sql = "UNION ALL".join(_build_checks_sql())
+
+ self.sql = f"SELECT col_name, check_type, check_result FROM ({checks_sql}) AS check_columns"
+
+ def execute(self, context: Context):
+ hook = self.get_db_hook()
+ records = hook.get_records(self.sql)
+
+ if not records:
+ self._raise_exception(f"The following query returned zero rows: {self.sql}")
+
+ self.log.info("Record: %s", records)
+
+ for column, check, result in records:
+ tolerance = self.column_mapping[column][check].get("tolerance")
+
+ self.column_mapping[column][check]["result"] = result
+ self.column_mapping[column][check]["success"] = self._get_match(
+ self.column_mapping[column][check], result, tolerance
+ )
+
+ failed_tests = [
+ f"Column: {col}\n\tCheck: {check},\n\tCheck Values: {check_values}\n"
+ for col, checks in self.column_mapping.items()
+ for check, check_values in checks.items()
+ if not check_values["success"]
+ ]
+ if failed_tests:
+ exception_string = (
+ f"Test failed.\nResults:\n{records!s}\n"
+ f"The following tests have failed:\n{''.join(failed_tests)}"
+ )
+ self._raise_exception(exception_string)
+
+ self.log.info("All tests have passed")
+
+ def _generate_sql_query(self, column, checks):
+ def _generate_partition_clause(check):
+ if self.partition_clause and "partition_clause" not in checks[check]:
+ return f"WHERE {self.partition_clause}"
+ elif not self.partition_clause and "partition_clause" in checks[check]:
+ return f"WHERE {checks[check]['partition_clause']}"
+ elif self.partition_clause and "partition_clause" in checks[check]:
+ return f"WHERE {self.partition_clause} AND {checks[check]['partition_clause']}"
+ else:
+ return ""
+
+ checks_sql = "UNION ALL".join(
+ self.sql_check_template.format(
+ check_statement=self.column_checks[check].format(column=column),
+ check=check,
+ table=self.table,
+ column=column,
+ partition_clause=_generate_partition_clause(check),
+ )
+ for check in checks
+ )
+ return checks_sql
+
+ def _get_match(self, check_values, record, tolerance=None) -> bool:
+ if record is None and self.accept_none:
+ record = 0
+ match_boolean = True
+ if "geq_to" in check_values:
+ if tolerance is not None:
+ match_boolean = record >= check_values["geq_to"] * (1 - tolerance)
+ else:
+ match_boolean = record >= check_values["geq_to"]
+ elif "greater_than" in check_values:
+ if tolerance is not None:
+ match_boolean = record > check_values["greater_than"] * (1 - tolerance)
+ else:
+ match_boolean = record > check_values["greater_than"]
+ if "leq_to" in check_values:
+ if tolerance is not None:
+ match_boolean = record <= check_values["leq_to"] * (1 + tolerance) and match_boolean
+ else:
+ match_boolean = record <= check_values["leq_to"] and match_boolean
+ elif "less_than" in check_values:
+ if tolerance is not None:
+ match_boolean = record < check_values["less_than"] * (1 + tolerance) and match_boolean
+ else:
+ match_boolean = record < check_values["less_than"] and match_boolean
+ if "equal_to" in check_values:
+ if tolerance is not None:
+ match_boolean = (
+ check_values["equal_to"] * (1 - tolerance)
+ <= record
+ <= check_values["equal_to"] * (1 + tolerance)
+ ) and match_boolean
+ else:
+ match_boolean = record == check_values["equal_to"] and match_boolean
+ return match_boolean
+
+ def _column_mapping_validation(self, check, check_values):
+ if check not in self.column_checks:
+ raise AirflowException(f"Invalid column check: {check}.")
+ if (
+ "greater_than" not in check_values
+ and "geq_to" not in check_values
+ and "less_than" not in check_values
+ and "leq_to" not in check_values
+ and "equal_to" not in check_values
+ ):
+ raise ValueError(
+ "Please provide one or more of: less_than, leq_to, "
+ "greater_than, geq_to, or equal_to in the check's dict."
+ )
+
+ if "greater_than" in check_values and "less_than" in check_values:
+ if check_values["greater_than"] >= check_values["less_than"]:
+ raise ValueError(
+ "greater_than should be strictly less than "
+ "less_than. Use geq_to or leq_to for "
+ "overlapping equality."
+ )
+
+ if "greater_than" in check_values and "leq_to" in check_values:
+ if check_values["greater_than"] >= check_values["leq_to"]:
+ raise ValueError(
+ "greater_than must be strictly less than leq_to. "
+ "Use geq_to with leq_to for overlapping equality."
+ )
+
+ if "geq_to" in check_values and "less_than" in check_values:
+ if check_values["geq_to"] >= check_values["less_than"]:
+ raise ValueError(
+ "geq_to should be strictly less than less_than. "
+ "Use leq_to with geq_to for overlapping equality."
+ )
+
+ if "geq_to" in check_values and "leq_to" in check_values:
+ if check_values["geq_to"] > check_values["leq_to"]:
+ raise ValueError("geq_to should be less than or equal to leq_to.")
+
+ if "greater_than" in check_values and "geq_to" in check_values:
+ raise ValueError("Only supply one of greater_than or geq_to.")
+
+ if "less_than" in check_values and "leq_to" in check_values:
+ raise ValueError("Only supply one of less_than or leq_to.")
+
+ if (
+ "greater_than" in check_values
+ or "geq_to" in check_values
+ or "less_than" in check_values
+ or "leq_to" in check_values
+ ) and "equal_to" in check_values:
+ raise ValueError(
+ "equal_to cannot be passed with a greater or less than "
+ "function. To specify 'greater than or equal to' or "
+ "'less than or equal to', use geq_to or leq_to."
+ )
+
+
+class SQLTableCheckOperator(BaseSQLOperator):
+ """
+ Performs one or more of the checks provided in the checks dictionary.
+ Checks should be written to return a boolean result.
+
+ :param table: the table to run checks on
+ :param checks: the dictionary of checks, where check names are followed by a dictionary containing at
+ least a check statement, and optionally a partition clause, e.g.:
+
+ .. code-block:: python
+
+ {
+ "row_count_check": {"check_statement": "COUNT(*) = 1000"},
+ "column_sum_check": {"check_statement": "col_a + col_b < col_c"},
+ "third_check": {"check_statement": "MIN(col) = 1", "partition_clause": "col IS NOT NULL"},
+ }
+
+
+ :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by
+ the operator that creates partition_clauses for the checks to run on, e.g.
+
+ .. code-block:: python
+
+ "date = '1970-01-01'"
+
+ :param conn_id: the connection ID used to connect to the database
+ :param database: name of the database that overrides the one defined in the connection
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the guide:
+ :ref:`howto/operator:SQLTableCheckOperator`
+ """
+
+ template_fields = ("partition_clause",)
+
+ sql_check_template = """
+ SELECT '{check_name}' AS check_name, MIN({check_name}) AS check_result
+ FROM (SELECT CASE WHEN {check_statement} THEN 1 ELSE 0 END AS {check_name}
+ FROM {table} {partition_clause}) AS sq
+ """
+
+ def __init__(
+ self,
+ *,
+ table: str,
+ checks: dict[str, dict[str, Any]],
+ partition_clause: str | None = None,
+ conn_id: str | None = None,
+ database: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+
+ self.table = table
+ self.checks = checks
+ self.partition_clause = partition_clause
+ self.sql = f"SELECT check_name, check_result FROM ({self._generate_sql_query()}) AS check_table"
+
+ def execute(self, context: Context):
+ hook = self.get_db_hook()
+ records = hook.get_records(self.sql)
+
+ if not records:
+ self._raise_exception(f"The following query returned zero rows: {self.sql}")
+
+ self.log.info("Record:\n%s", records)
+
+ for row in records:
+ check, result = row
+ self.checks[check]["success"] = _parse_boolean(str(result))
+
+ failed_tests = [
+ f"\tCheck: {check},\n\tCheck Values: {check_values}\n"
+ for check, check_values in self.checks.items()
+ if not check_values["success"]
+ ]
+ if failed_tests:
+ exception_string = (
+ f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}\n"
+ f"The following tests have failed:\n{', '.join(failed_tests)}"
+ )
+ self._raise_exception(exception_string)
+
+ self.log.info("All tests have passed")
+
+ def _generate_sql_query(self):
+ def _generate_partition_clause(check_name):
+ if self.partition_clause and "partition_clause" not in self.checks[check_name]:
+ return f"WHERE {self.partition_clause}"
+ elif not self.partition_clause and "partition_clause" in self.checks[check_name]:
+ return f"WHERE {self.checks[check_name]['partition_clause']}"
+ elif self.partition_clause and "partition_clause" in self.checks[check_name]:
+ return f"WHERE {self.partition_clause} AND {self.checks[check_name]['partition_clause']}"
+ else:
+ return ""
+
+ return "UNION ALL".join(
+ self.sql_check_template.format(
+ check_statement=value["check_statement"],
+ check_name=check_name,
+ table=self.table,
+ partition_clause=_generate_partition_clause(check_name),
+ )
+ for check_name, value in self.checks.items()
+ )
+
+
+class SQLCheckOperator(BaseSQLOperator):
+ """
+ Performs checks against a database. The ``SQLCheckOperator`` expects
+ a SQL query that returns a single row. Each value in that
+ row is evaluated using Python ``bool`` casting. If any of the
+ values evaluates to ``False``, the check fails and the task errors out.
+
+ Note that Python bool casting evals the following as ``False``:
+
+ * ``False``
+ * ``0``
+ * Empty string (``""``)
+ * Empty list (``[]``)
+ * Empty dictionary (``{}``) or empty set (``set()``)
+
+ Given a query like ``SELECT COUNT(*) FROM foo``, it will fail only if
+ the count ``== 0``. You can craft much more complex queries that could,
+ for instance, check that the table has the same number of rows as
+ the source table upstream, or that the count of today's partition is
+ greater than yesterday's partition, or that a set of metrics is less
+ than 3 standard deviations from the 7-day average.
+
+ This operator can be used as a data quality check in your pipeline.
+ Depending on where you put it in your DAG, you can either place it on
+ the critical path, stopping the DAG and preventing dubious data from
+ being published, or place it on the side and receive email alerts
+ without stopping the progress of the DAG.
+
+ :param sql: the sql to be executed. (templated)
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of the database that overrides the one defined in the connection
+ :param parameters: (optional) the parameters to render the SQL query with.
+ """
+
+ template_fields: Sequence[str] = ("sql",)
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#fff7e6"
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ conn_id: str | None = None,
+ database: str | None = None,
+ parameters: Iterable | Mapping | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.parameters = parameters
+
+ def execute(self, context: Context):
+ self.log.info("Executing SQL check: %s", self.sql)
+ records = self.get_db_hook().get_first(self.sql, self.parameters)
+
+ self.log.info("Record: %s", records)
+ if not records:
+ self._raise_exception(f"The following query returned zero rows: {self.sql}")
+ elif not all(bool(r) for r in records):
+ self._raise_exception(f"Test failed.\nQuery:\n{self.sql}\nResults:\n{records!s}")
+
+ self.log.info("Success.")
+
+
+class SQLValueCheckOperator(BaseSQLOperator):
+ """
+ Performs a simple value check using sql code.
+
+ :param sql: the sql to be executed. (templated)
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of the database that overrides the one defined in the connection
+ """
+
+ __mapper_args__ = {"polymorphic_identity": "SQLValueCheckOperator"}
+ template_fields: Sequence[str] = (
+ "sql",
+ "pass_value",
+ )
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#fff7e6"
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ pass_value: Any,
+ tolerance: Any = None,
+ conn_id: str | None = None,
+ database: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.pass_value = str(pass_value)
+ tol = _convert_to_float_if_possible(tolerance)
+ self.tol = tol if isinstance(tol, float) else None
+ self.has_tolerance = self.tol is not None
+
+ def execute(self, context: Context):
+ self.log.info("Executing SQL check: %s", self.sql)
+ records = self.get_db_hook().get_first(self.sql)
+
+ if not records:
+ self._raise_exception(f"The following query returned zero rows: {self.sql}")
+
+ pass_value_conv = _convert_to_float_if_possible(self.pass_value)
+ is_numeric_value_check = isinstance(pass_value_conv, float)
+
+ tolerance_pct_str = str(self.tol * 100) + "%" if self.tol is not None else None
+ error_msg = (
+ "Test failed.\nPass value:{pass_value_conv}\n"
+ "Tolerance:{tolerance_pct_str}\n"
+ "Query:\n{sql}\nResults:\n{records!s}"
+ ).format(
+ pass_value_conv=pass_value_conv,
+ tolerance_pct_str=tolerance_pct_str,
+ sql=self.sql,
+ records=records,
+ )
+
+ if not is_numeric_value_check:
+ tests = self._get_string_matches(records, pass_value_conv)
+ elif is_numeric_value_check:
+ try:
+ numeric_records = self._to_float(records)
+ except (ValueError, TypeError):
+ raise AirflowException(f"Converting a result to float failed.\n{error_msg}")
+ tests = self._get_numeric_matches(numeric_records, pass_value_conv)
+ else:
+ tests = []
+
+ if not all(tests):
+ self._raise_exception(error_msg)
+
+ def _to_float(self, records):
+ return [float(record) for record in records]
+
+ def _get_string_matches(self, records, pass_value_conv):
+ return [str(record) == pass_value_conv for record in records]
+
+ def _get_numeric_matches(self, numeric_records, numeric_pass_value_conv):
+ if self.has_tolerance:
+ return [
+ numeric_pass_value_conv * (1 - self.tol) <= record <= numeric_pass_value_conv * (1 + self.tol)
+ for record in numeric_records
+ ]
+
+ return [record == numeric_pass_value_conv for record in numeric_records]
+
+
+class SQLIntervalCheckOperator(BaseSQLOperator):
+ """
+ Checks that the values of metrics given as SQL expressions are within
+ a certain tolerance of the ones from days_back before.
+
+ :param table: the table name
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which will overwrite the defined one in connection
+ :param days_back: number of days between ds and the ds we want to check
+ against. Defaults to 7 days
+ :param date_filter_column: The column name for the dates to filter on. Defaults to 'ds'
+ :param ratio_formula: which formula to use to compute the ratio between
+ the two metrics. Assuming cur is the metric of today and ref is
+ the metric of today - days_back.
+
+ max_over_min: computes max(cur, ref) / min(cur, ref)
+ relative_diff: computes abs(cur-ref) / ref
+
+ Default: 'max_over_min'
+ :param ignore_zero: whether we should ignore zero metrics
+ :param metrics_thresholds: a dictionary of ratios indexed by metrics
+ """
+
+ __mapper_args__ = {"polymorphic_identity": "SQLIntervalCheckOperator"}
+ template_fields: Sequence[str] = ("sql1", "sql2")
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql1": "sql", "sql2": "sql"}
+ ui_color = "#fff7e6"
+
+ ratio_formulas = {
+ "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
+ "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
+ }
+
+ def __init__(
+ self,
+ *,
+ table: str,
+ metrics_thresholds: dict[str, int],
+ date_filter_column: str | None = "ds",
+ days_back: SupportsAbs[int] = -7,
+ ratio_formula: str | None = "max_over_min",
+ ignore_zero: bool = True,
+ conn_id: str | None = None,
+ database: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ if ratio_formula not in self.ratio_formulas:
+ msg_template = "Invalid diff_method: {diff_method}. Supported diff methods are: {diff_methods}"
+
+ raise AirflowFailException(
+ msg_template.format(diff_method=ratio_formula, diff_methods=self.ratio_formulas)
+ )
+ self.ratio_formula = ratio_formula
+ self.ignore_zero = ignore_zero
+ self.table = table
+ self.metrics_thresholds = metrics_thresholds
+ self.metrics_sorted = sorted(metrics_thresholds.keys())
+ self.date_filter_column = date_filter_column
+ self.days_back = -abs(days_back)
+ sqlexp = ", ".join(self.metrics_sorted)
+ sqlt = f"SELECT {sqlexp} FROM {table} WHERE {date_filter_column}="
+
+ self.sql1 = sqlt + "'{{ ds }}'"
+ self.sql2 = sqlt + "'{{ macros.ds_add(ds, " + str(self.days_back) + ") }}'"
+
+ def execute(self, context: Context):
+ hook = self.get_db_hook()
+ self.log.info("Using ratio formula: %s", self.ratio_formula)
+ self.log.info("Executing SQL check: %s", self.sql2)
+ row2 = hook.get_first(self.sql2)
+ self.log.info("Executing SQL check: %s", self.sql1)
+ row1 = hook.get_first(self.sql1)
+
+ if not row2:
+ self._raise_exception(f"The following query returned zero rows: {self.sql2}")
+ if not row1:
+ self._raise_exception(f"The following query returned zero rows: {self.sql1}")
+
+ current = dict(zip(self.metrics_sorted, row1))
+ reference = dict(zip(self.metrics_sorted, row2))
+
+ ratios: dict[str, float | None] = {}
+ test_results = {}
+
+ for metric in self.metrics_sorted:
+ cur = current[metric]
+ ref = reference[metric]
+ threshold = self.metrics_thresholds[metric]
+ if cur == 0 or ref == 0:
+ ratios[metric] = None
+ test_results[metric] = self.ignore_zero
+ else:
+ ratio_metric = self.ratio_formulas[self.ratio_formula](current[metric], reference[metric])
+ ratios[metric] = ratio_metric
+ if ratio_metric is not None:
+ test_results[metric] = ratio_metric < threshold
+ else:
+ test_results[metric] = self.ignore_zero
+
+ self.log.info(
+ (
+ "Current metric for %s: %s\n"
+ "Past metric for %s: %s\n"
+ "Ratio for %s: %s\n"
+ "Threshold: %s\n"
+ ),
+ metric,
+ cur,
+ metric,
+ ref,
+ metric,
+ ratios[metric],
+ threshold,
+ )
+
+ if not all(test_results.values()):
+ failed_tests = [it[0] for it in test_results.items() if not it[1]]
+ self.log.warning(
+ "The following %s tests out of %s failed:",
+ len(failed_tests),
+ len(self.metrics_sorted),
+ )
+ for k in failed_tests:
+ self.log.warning(
+ "'%s' check failed. %s is above %s",
+ k,
+ ratios[k],
+ self.metrics_thresholds[k],
+ )
+ self._raise_exception(f"The following tests have failed:\n {', '.join(sorted(failed_tests))}")
+
+ self.log.info("All tests have passed")
+
+
+class SQLThresholdCheckOperator(BaseSQLOperator):
+ """
+ Performs a value check using sql code against a minimum threshold
+ and a maximum threshold. Thresholds can be in the form of a numeric
+ value OR a sql statement that results in a numeric value.
+
+ :param sql: the sql to be executed. (templated)
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which overwrites the defined one in connection
+ :param min_threshold: numerical value or min threshold sql to be executed (templated)
+ :param max_threshold: numerical value or max threshold sql to be executed (templated)
+ """
+
+ template_fields: Sequence[str] = ("sql", "min_threshold", "max_threshold")
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ template_fields_renderers = {"sql": "sql"}
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ min_threshold: Any,
+ max_threshold: Any,
+ conn_id: str | None = None,
+ database: str | None = None,
+ **kwargs,
+ ):
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.min_threshold = _convert_to_float_if_possible(min_threshold)
+ self.max_threshold = _convert_to_float_if_possible(max_threshold)
+
+ def execute(self, context: Context):
+ hook = self.get_db_hook()
+ result = hook.get_first(self.sql)[0]
+ if not result:
+ self._raise_exception(f"The following query returned zero rows: {self.sql}")
+
+ if isinstance(self.min_threshold, float):
+ lower_bound = self.min_threshold
+ else:
+ lower_bound = hook.get_first(self.min_threshold)[0]
+
+ if isinstance(self.max_threshold, float):
+ upper_bound = self.max_threshold
+ else:
+ upper_bound = hook.get_first(self.max_threshold)[0]
+
+ meta_data = {
+ "result": result,
+ "task_id": self.task_id,
+ "min_threshold": lower_bound,
+ "max_threshold": upper_bound,
+ "within_threshold": lower_bound <= result <= upper_bound,
+ }
+
+ self.push(meta_data)
+ if not meta_data["within_threshold"]:
+ result = (
+ round(meta_data.get("result"), 2) # type: ignore[arg-type]
+ if meta_data.get("result") is not None
+ else ""
+ )
+ error_msg = (
+ f'Threshold Check: "{meta_data.get("task_id")}" failed.\n'
+ f'DAG: {self.dag_id}\nTask_id: {meta_data.get("task_id")}\n'
+ f'Check description: {meta_data.get("description")}\n'
+ f"SQL: {self.sql}\n"
+ f"Result: {result} is not within thresholds "
+ f'{meta_data.get("min_threshold")} and {meta_data.get("max_threshold")}'
+ )
+ self._raise_exception(error_msg)
+
+ self.log.info("Test %s Successful.", self.task_id)
+
+ def push(self, meta_data):
+ """
+ Optional: Send data check info and metadata to an external database.
+ Default functionality will log metadata.
+ """
+ info = "\n".join(f"""{key}: {item}""" for key, item in meta_data.items())
+ self.log.info("Log from %s:\n%s", self.dag_id, info)
+
+
+class BranchSQLOperator(BaseSQLOperator, SkipMixin):
+ """
+ Allows a DAG to "branch" or follow a specified path based on the results of a SQL query.
+
+ :param sql: The SQL code to be executed, should return true or false (templated)
+ Template references are recognized by str ending in '.sql'.
+ The SQL query is expected to return a Boolean (True/False), an integer (0 = False, otherwise True)
+ or a string (true/y/yes/1/on/false/n/no/0/off).
+ :param follow_task_ids_if_true: task id or task ids to follow if query returns true
+ :param follow_task_ids_if_false: task id or task ids to follow if query returns false
+ :param conn_id: the connection ID used to connect to the database.
+ :param database: name of database which overwrites the defined one in connection
+ :param parameters: (optional) the parameters to render the SQL query with.
+ """
+
+ template_fields: Sequence[str] = ("sql",)
+ template_ext: Sequence[str] = (".sql",)
+ template_fields_renderers = {"sql": "sql"}
+ ui_color = "#a22034"
+ ui_fgcolor = "#F7F7F7"
+
+ def __init__(
+ self,
+ *,
+ sql: str,
+ follow_task_ids_if_true: list[str],
+ follow_task_ids_if_false: list[str],
+ conn_id: str = "default_conn_id",
+ database: str | None = None,
+ parameters: Iterable | Mapping | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(conn_id=conn_id, database=database, **kwargs)
+ self.sql = sql
+ self.parameters = parameters
+ self.follow_task_ids_if_true = follow_task_ids_if_true
+ self.follow_task_ids_if_false = follow_task_ids_if_false
+
+ def execute(self, context: Context):
+ self.log.info(
+ "Executing: %s (with parameters %s) with connection: %s",
+ self.sql,
+ self.parameters,
+ self.conn_id,
+ )
+ record = self.get_db_hook().get_first(self.sql, self.parameters)
+ if not record:
+ raise AirflowException(
+ "No rows returned from sql query. Operator expected True or False return value."
+ )
+
+ if isinstance(record, list):
+ if isinstance(record[0], list):
+ query_result = record[0][0]
+ else:
+ query_result = record[0]
+ elif isinstance(record, tuple):
+ query_result = record[0]
+ else:
+ query_result = record
+
+ self.log.info("Query returns %s, type '%s'", query_result, type(query_result))
+
+ follow_branch = None
+ try:
+ if isinstance(query_result, bool):
+ if query_result:
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, str):
+ # return result is not Boolean, try to convert from String to Boolean
+ if _parse_boolean(query_result):
+ follow_branch = self.follow_task_ids_if_true
+ elif isinstance(query_result, int):
+ if bool(query_result):
+ follow_branch = self.follow_task_ids_if_true
+ else:
+ raise AirflowException(
+ f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
+ )
+
+ if follow_branch is None:
+ follow_branch = self.follow_task_ids_if_false
+ except ValueError:
+ raise AirflowException(
+ f"Unexpected query return result '{query_result}' type '{type(query_result)}'"
+ )
+
+ self.skip_all_except(context["ti"], follow_branch)
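
The per-metric ratio check performed in ``SQLIntervalCheckOperator.execute`` above can be sketched in isolation. This is only a minimal illustration under stated assumptions, not the operator itself: ``check_metrics`` and its arguments are hypothetical names, and the two ``hook.get_first`` queries are replaced by plain dictionaries of already-fetched values.

```python
from __future__ import annotations

# Same two formulas as SQLIntervalCheckOperator.ratio_formulas in the diff above.
RATIO_FORMULAS = {
    "max_over_min": lambda cur, ref: float(max(cur, ref)) / min(cur, ref),
    "relative_diff": lambda cur, ref: float(abs(cur - ref)) / ref,
}


def check_metrics(current, reference, thresholds, ratio_formula="max_over_min", ignore_zero=True):
    """Hypothetical helper: return {metric: passed} following the rules of the execute() loop."""
    formula = RATIO_FORMULAS[ratio_formula]
    results = {}
    for metric in sorted(thresholds):
        cur, ref = current[metric], reference[metric]
        if cur == 0 or ref == 0:
            # A zero on either side skips the ratio; the outcome is decided by ignore_zero.
            results[metric] = ignore_zero
        else:
            # The check passes only when the ratio is strictly below the threshold.
            results[metric] = formula(cur, ref) < thresholds[metric]
    return results


# Made-up metric values standing in for the rows of sql1 (today) and sql2 (days_back ago).
results = check_metrics(
    current={"COUNT(*)": 120, "SUM(amount)": 0},
    reference={"COUNT(*)": 100, "SUM(amount)": 50},
    thresholds={"COUNT(*)": 1.5, "SUM(amount)": 2},
)
print(results)
```

With ``max_over_min`` the threshold is a multiplicative bound (a ratio of 1.0 means no change), while with ``relative_diff`` it is a fractional change, so the same threshold value means different things under the two formulas.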
diff --git a/airflow/providers/common/sql/provider.yaml b/airflow/providers/common/sql/provider.yaml
new file mode 100644
index 0000000000000..4527dfff586a6
--- /dev/null
+++ b/airflow/providers/common/sql/provider.yaml
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-common-sql
+name: Common SQL
+description: |
+ `Common SQL Provider `__
+
+versions:
+ - 1.3.1
+ - 1.3.0
+ - 1.2.0
+ - 1.1.0
+ - 1.0.0
+
+dependencies:
+ - sqlparse>=0.4.2
+
+additional-extras:
+ - name: pandas
+ dependencies:
+ - pandas>=0.17.1
+
+integrations:
+ - integration-name: Common SQL
+ external-doc-url: https://en.wikipedia.org/wiki/SQL
+ how-to-guide:
+ - /docs/apache-airflow-providers-common-sql/operators.rst
+ logo: /integration-logos/common/sql/sql.png
+ tags: [software]
+
+operators:
+ - integration-name: Common SQL
+ python-modules:
+ - airflow.providers.common.sql.operators.sql
+
+hooks:
+ - integration-name: Common SQL
+ python-modules:
+ - airflow.providers.common.sql.hooks.sql
+
+sensors:
+ - integration-name: Common SQL
+ python-modules:
+ - airflow.providers.common.sql.sensors.sql
diff --git a/airflow/providers/jdbc/example_dags/__init__.py b/airflow/providers/common/sql/sensors/__init__.py
similarity index 100%
rename from airflow/providers/jdbc/example_dags/__init__.py
rename to airflow/providers/common/sql/sensors/__init__.py
diff --git a/airflow/providers/common/sql/sensors/sql.py b/airflow/providers/common/sql/sensors/sql.py
new file mode 100644
index 0000000000000..d58802dc98b03
--- /dev/null
+++ b/airflow/providers/common/sql/sensors/sql.py
@@ -0,0 +1,110 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Any, Sequence
+
+from airflow import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
+from airflow.sensors.base import BaseSensorOperator
+
+
+class SqlSensor(BaseSensorOperator):
+ """
+ Runs a SQL statement repeatedly until a criterion is met. It will keep trying until
+ the success or failure criteria are met, or until the first cell is truthy (not 0, '' or None).
+ Optional success and failure callables are called with the first cell returned as the argument.
+ If a success callable is defined, the sensor will keep retrying until its criterion is met.
+ If a failure callable is defined and its criterion is met, the sensor will raise AirflowException.
+ Failure criteria are evaluated before success criteria. A fail_on_empty boolean can also
+ be passed to the sensor, in which case it will fail if no rows have been returned.
+
+ :param conn_id: The connection to run the sensor against
+ :param sql: The sql to run. To pass, it needs to return at least one cell
+ that contains a non-zero, non-empty value.
+ :param parameters: The parameters to render the SQL query with (optional).
+ :param success: Success criteria for the sensor is a Callable that takes first_cell
+ as the only argument, and returns a boolean (optional).
+ :param failure: Failure criteria for the sensor is a Callable that takes first_cell
+ as the only argument and returns a boolean (optional).
+ :param fail_on_empty: Explicitly fail on no rows returned.
+ :param hook_params: Extra config params to be passed to the underlying hook.
+ Should match the desired hook constructor params.
+ """
+
+ template_fields: Sequence[str] = ("sql",)
+ template_ext: Sequence[str] = (
+ ".hql",
+ ".sql",
+ )
+ ui_color = "#7c7287"
+
+ def __init__(
+ self,
+ *,
+ conn_id,
+ sql,
+ parameters=None,
+ success=None,
+ failure=None,
+ fail_on_empty=False,
+ hook_params=None,
+ **kwargs,
+ ):
+ self.conn_id = conn_id
+ self.sql = sql
+ self.parameters = parameters
+ self.success = success
+ self.failure = failure
+ self.fail_on_empty = fail_on_empty
+ self.hook_params = hook_params
+ super().__init__(**kwargs)
+
+ def _get_hook(self):
+ conn = BaseHook.get_connection(self.conn_id)
+ hook = conn.get_hook(hook_params=self.hook_params)
+ if not isinstance(hook, DbApiHook):
+ raise AirflowException(
+ f"The connection type is not supported by {self.__class__.__name__}. "
+ f"The associated hook should be a subclass of `DbApiHook`. Got {hook.__class__.__name__}"
+ )
+ return hook
+
+ def poke(self, context: Any):
+ hook = self._get_hook()
+
+ self.log.info("Poking: %s (with parameters %s)", self.sql, self.parameters)
+ records = hook.get_records(self.sql, self.parameters)
+ if not records:
+ if self.fail_on_empty:
+ raise AirflowException("No rows returned, raising as per fail_on_empty flag")
+ else:
+ return False
+ first_cell = records[0][0]
+ if self.failure is not None:
+ if callable(self.failure):
+ if self.failure(first_cell):
+ raise AirflowException(f"Failure criteria met. self.failure({first_cell}) returned True")
+ else:
+ raise AirflowException(f"self.failure is present, but not callable -> {self.failure}")
+ if self.success is not None:
+ if callable(self.success):
+ return self.success(first_cell)
+ else:
+ raise AirflowException(f"self.success is present, but not callable -> {self.success}")
+ return bool(first_cell)
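
The decision order in ``SqlSensor.poke`` above (empty result, then failure criteria, then success criteria, then truthiness of the first cell) can be sketched as a standalone function. This is a simplified illustration, not the sensor's code: ``evaluate_poke`` is a hypothetical name, ``RuntimeError`` stands in for ``AirflowException``, the ``callable()`` validation is omitted, and ``records`` is passed in directly instead of being fetched through a hook.

```python
def evaluate_poke(records, success=None, failure=None, fail_on_empty=False):
    """Hypothetical helper mirroring the branch order of SqlSensor.poke."""
    if not records:
        if fail_on_empty:
            raise RuntimeError("No rows returned, raising as per fail_on_empty flag")
        return False  # no rows yet: keep poking

    first_cell = records[0][0]

    # Failure criteria are evaluated before success criteria.
    if failure is not None and failure(first_cell):
        raise RuntimeError(f"Failure criteria met for first cell {first_cell!r}")

    if success is not None:
        return bool(success(first_cell))

    # Without callables, the sensor succeeds once the first cell is truthy.
    return bool(first_cell)


print(evaluate_poke([]))                                     # no rows
print(evaluate_poke([(0,)]))                                 # falsy first cell
print(evaluate_poke([(5,)]))                                 # truthy first cell
print(evaluate_poke([(5,)], success=lambda cell: cell > 10)) # success callable decides
```

Because the final fallback is plain Python truthiness, a non-empty string such as ``'0'`` counts as a passing value in this sketch; a ``success`` callable is one way to apply a stricter rule.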
diff --git a/airflow/providers/databricks/CHANGELOG.rst b/airflow/providers/databricks/CHANGELOG.rst
index beb539e8a7b69..0d446920542b1 100644
--- a/airflow/providers/databricks/CHANGELOG.rst
+++ b/airflow/providers/databricks/CHANGELOG.rst
@@ -16,9 +16,137 @@
under the License.
+.. NOTE TO CONTRIBUTORS:
+ Please, only add notes to the Changelog just below the "Changelog" header when there are some breaking changes
+ and you want to add an explanation to the users on how they are supposed to deal with them.
+ The changelog is updated and maintained semi-automatically by release manager.
+
Changelog
---------
+3.4.0
+.....
+
+This release of provider is only available for Airflow 2.3+ as explained in the
+`Apache Airflow providers support policy <https://github.com/apache/airflow/blob/main/README.md#support-for-providers>`_.
+
+Misc
+~~~~
+
+* ``Move min airflow version to 2.3.0 for all providers (#27196)``
+* ``Replace urlparse with urlsplit (#27389)``
+
+Features
+~~~~~~~~
+
+* ``Add SQLExecuteQueryOperator (#25717)``
+* ``Use new job search API for triggering Databricks job by name (#27446)``
+
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Update old style typing (#26872)``
+ * ``Enable string normalization in python formatting - providers (#27205)``
+
+3.3.0
+.....
+
+Features
+~~~~~~~~
+
+* ``DatabricksSubmitRunOperator dbt task support (#25623)``
+
+Misc
+~~~~
+
+* ``Add common-sql lower bound for common-sql (#25789)``
+* ``Remove duplicated connection-type within the provider (#26628)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Databricks: fix provider name in the User-Agent string (#25873)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Apply PEP-563 (Postponed Evaluation of Annotations) to non-core airflow (#26289)``
+ * ``D400 first line should end with period batch02 (#25268)``
+
+3.2.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Databricks: update user-agent string (#25578)``
+* ``More improvements in the Databricks operators (#25260)``
+* ``Improved telemetry for Databricks provider (#25115)``
+* ``Unify DbApiHook.run() method with the methods which override it (#23971)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Databricks: fix test_connection implementation (#25114)``
+* ``Do not convert boolean values to string in deep_string_coerce function (#25394)``
+* ``Correctly handle output of the failed tasks (#25427)``
+* ``Databricks: Fix provider for Airflow 2.2.x (#25674)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``updated documentation for databricks operator (#24599)``
+ * ``Prepare docs for new providers release (August 2022) (#25618)``
+
+3.1.0
+.....
+
+Features
+~~~~~~~~
+
+* ``Added databricks_conn_id as templated field (#24945)``
+* ``Add 'test_connection' method to Databricks hook (#24617)``
+* ``Move all SQL classes to common-sql provider (#24836)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``Update providers to use functools compat for ''cached_property'' (#24582)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``Automatically detect if non-lazy logging interpolation is used (#24910)``
+ * ``Remove "bad characters" from our codebase (#24841)``
+ * ``Move provider dependencies to inside provider folders (#24672)``
+ * ``Remove 'hook-class-names' from provider.yaml (#24702)``
+
+3.0.0
+.....
+
+Breaking changes
+~~~~~~~~~~~~~~~~
+
+* This release of provider is only available for Airflow 2.2+ as explained in the Apache Airflow
+ providers support policy https://github.com/apache/airflow/blob/main/README.md#support-for-providers
+
+Features
+~~~~~~~~
+
+* ``Add Deferrable Databricks operators (#19736)``
+* ``Add git_source to DatabricksSubmitRunOperator (#23620)``
+
+Bug Fixes
+~~~~~~~~~
+
+* ``fix: DatabricksSubmitRunOperator and DatabricksRunNowOperator cannot define .json as template_ext (#23622) (#23641)``
+* ``Fix UnboundLocalError when sql is empty list in DatabricksSqlHook (#23815)``
+
+.. Below changes are excluded from the changelog. Move them to
+ appropriate section above if needed. Do not delete the lines(!):
+ * ``AIP-47 - Migrate databricks DAGs to new design #22442 (#24203)``
+ * ``Introduce 'flake8-implicit-str-concat' plugin to static checks (#23873)``
+ * ``Add explanatory note for contributors about updating Changelog (#24229)``
+ * ``Prepare docs for May 2022 provider's release (#24231)``
+ * ``Update package description to remove double min-airflow specification (#24292)``
+
2.7.0
.....
@@ -62,7 +190,6 @@ Misc
* ``Fix new MyPy errors in main (#22884)``
* ``Prepare mid-April provider documentation. (#22819)``
-.. Review and move the new changes to one of the sections above:
* ``Prepare for RC2 release of March Databricks provider (#22979)``
2.5.0
diff --git a/airflow/providers/databricks/example_dags/example_databricks.py b/airflow/providers/databricks/example_dags/example_databricks.py
deleted file mode 100644
index bea9038afeb0a..0000000000000
--- a/airflow/providers/databricks/example_dags/example_databricks.py
+++ /dev/null
@@ -1,75 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This is an example DAG which uses the DatabricksSubmitRunOperator.
-In this example, we create two tasks which execute sequentially.
-The first task is to run a notebook at the workspace path "/test"
-and the second task is to run a JAR uploaded to DBFS. Both,
-tasks use new clusters.
-
-Because we have set a downstream dependency on the notebook task,
-the spark jar task will NOT run until the notebook task completes
-successfully.
-
-The definition of a successful run is if the run has a result_state of "SUCCESS".
-For more information about the state of a run refer to
-https://docs.databricks.com/api/latest/jobs.html#runstate
-"""
-
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
-
-with DAG(
- dag_id='example_databricks_operator',
- schedule_interval='@daily',
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
- # [START howto_operator_databricks_json]
- # Example of using the JSON parameter to initialize the operator.
- new_cluster = {
- 'spark_version': '9.1.x-scala2.12',
- 'node_type_id': 'r3.xlarge',
- 'aws_attributes': {'availability': 'ON_DEMAND'},
- 'num_workers': 8,
- }
-
- notebook_task_params = {
- 'new_cluster': new_cluster,
- 'notebook_task': {
- 'notebook_path': '/Users/airflow@example.com/PrepareData',
- },
- }
-
- notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params)
- # [END howto_operator_databricks_json]
-
- # [START howto_operator_databricks_named]
- # Example of using the named parameters of DatabricksSubmitRunOperator
- # to initialize the operator.
- spark_jar_task = DatabricksSubmitRunOperator(
- task_id='spark_jar_task',
- new_cluster=new_cluster,
- spark_jar_task={'main_class_name': 'com.example.ProcessData'},
- libraries=[{'jar': 'dbfs:/lib/etl-0.1.jar'}],
- )
- # [END howto_operator_databricks_named]
- notebook_task >> spark_jar_task
diff --git a/airflow/providers/databricks/example_dags/example_databricks_repos.py b/airflow/providers/databricks/example_dags/example_databricks_repos.py
deleted file mode 100644
index e33d32044f5df..0000000000000
--- a/airflow/providers/databricks/example_dags/example_databricks_repos.py
+++ /dev/null
@@ -1,74 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.databricks.operators.databricks import DatabricksSubmitRunOperator
-from airflow.providers.databricks.operators.databricks_repos import (
- DatabricksReposCreateOperator,
- DatabricksReposDeleteOperator,
- DatabricksReposUpdateOperator,
-)
-
-default_args = {
- 'owner': 'airflow',
- 'databricks_conn_id': 'databricks',
-}
-
-with DAG(
- dag_id='example_databricks_repos_operator',
- schedule_interval='@daily',
- start_date=datetime(2021, 1, 1),
- default_args=default_args,
- tags=['example'],
- catchup=False,
-) as dag:
- # [START howto_operator_databricks_repo_create]
- # Example of creating a Databricks Repo
- repo_path = "/Repos/user@domain.com/demo-repo"
- git_url = "https://github.com/test/test"
- create_repo = DatabricksReposCreateOperator(task_id='create_repo', repo_path=repo_path, git_url=git_url)
- # [END howto_operator_databricks_repo_create]
-
- # [START howto_operator_databricks_repo_update]
- # Example of updating a Databricks Repo to the latest code
- repo_path = "/Repos/user@domain.com/demo-repo"
- update_repo = DatabricksReposUpdateOperator(task_id='update_repo', repo_path=repo_path, branch="releases")
- # [END howto_operator_databricks_repo_update]
-
- notebook_task_params = {
- 'new_cluster': {
- 'spark_version': '9.1.x-scala2.12',
- 'node_type_id': 'r3.xlarge',
- 'aws_attributes': {'availability': 'ON_DEMAND'},
- 'num_workers': 8,
- },
- 'notebook_task': {
- 'notebook_path': f'{repo_path}/PrepareData',
- },
- }
-
- notebook_task = DatabricksSubmitRunOperator(task_id='notebook_task', json=notebook_task_params)
-
- # [START howto_operator_databricks_repo_delete]
- # Example of deleting a Databricks Repo
- repo_path = "/Repos/user@domain.com/demo-repo"
- delete_repo = DatabricksReposDeleteOperator(task_id='delete_repo', repo_path=repo_path)
- # [END howto_operator_databricks_repo_delete]
-
- (create_repo >> update_repo >> notebook_task >> delete_repo)
diff --git a/airflow/providers/databricks/example_dags/example_databricks_sql.py b/airflow/providers/databricks/example_dags/example_databricks_sql.py
deleted file mode 100644
index 6032c0fb03032..0000000000000
--- a/airflow/providers/databricks/example_dags/example_databricks_sql.py
+++ /dev/null
@@ -1,113 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""
-This is an example DAG which uses the DatabricksSubmitRunOperator.
-In this example, we create two tasks which execute sequentially.
-The first task is to run a notebook at the workspace path "/test"
-and the second task is to run a JAR uploaded to DBFS. Both,
-tasks use new clusters.
-
-Because we have set a downstream dependency on the notebook task,
-the spark jar task will NOT run until the notebook task completes
-successfully.
-
-The definition of a successful run is if the run has a result_state of "SUCCESS".
-For more information about the state of a run refer to
-https://docs.databricks.com/api/latest/jobs.html#runstate
-"""
-
-from datetime import datetime
-
-from airflow import DAG
-from airflow.providers.databricks.operators.databricks_sql import (
- DatabricksCopyIntoOperator,
- DatabricksSqlOperator,
-)
-
-with DAG(
- dag_id='example_databricks_sql_operator',
- schedule_interval='@daily',
- start_date=datetime(2021, 1, 1),
- tags=['example'],
- catchup=False,
-) as dag:
- connection_id = 'my_connection'
- sql_endpoint_name = "My Endpoint"
-
- # [START howto_operator_databricks_sql_multiple]
- # Example of using the Databricks SQL Operator to perform multiple operations.
- create = DatabricksSqlOperator(
- databricks_conn_id=connection_id,
- sql_endpoint_name=sql_endpoint_name,
- task_id='create_and_populate_table',
- sql=[
- "drop table if exists default.my_airflow_table",
- "create table default.my_airflow_table(id int, v string)",
- "insert into default.my_airflow_table values (1, 'test 1'), (2, 'test 2')",
- ],
- )
- # [END howto_operator_databricks_sql_multiple]
-
- # [START howto_operator_databricks_sql_select]
- # Example of using the Databricks SQL Operator to select data.
- select = DatabricksSqlOperator(
- databricks_conn_id=connection_id,
- sql_endpoint_name=sql_endpoint_name,
- task_id='select_data',
- sql="select * from default.my_airflow_table",
- )
- # [END howto_operator_databricks_sql_select]
-
- # [START howto_operator_databricks_sql_select_file]
- # Example of using the Databricks SQL Operator to select data into a file with JSONL format.
- select_into_file = DatabricksSqlOperator(
- databricks_conn_id=connection_id,
- sql_endpoint_name=sql_endpoint_name,
- task_id='select_data_into_file',
- sql="select * from default.my_airflow_table",
- output_path="/tmp/1.jsonl",
- output_format="jsonl",
- )
- # [END howto_operator_databricks_sql_select_file]
-
- # [START howto_operator_databricks_sql_multiple_file]
- # Example of using the Databricks SQL Operator to select data.
- # SQL statements should be in the file with name test.sql
- create_file = DatabricksSqlOperator(
- databricks_conn_id=connection_id,
- sql_endpoint_name=sql_endpoint_name,
- task_id='create_and_populate_from_file',
- sql="test.sql",
- )
- # [END howto_operator_databricks_sql_multiple_file]
-
- # [START howto_operator_databricks_copy_into]
- # Example of importing data using COPY_INTO SQL command
- import_csv = DatabricksCopyIntoOperator(
- task_id='import_csv',
- databricks_conn_id=connection_id,
- sql_endpoint_name=sql_endpoint_name,
- table_name="my_table",
- file_format="CSV",
- file_location="abfss://container@account.dfs.core.windows.net/my-data/csv",
- format_options={'header': 'true'},
- force_copy=True,
- )
- # [END howto_operator_databricks_copy_into]
-
- (create >> create_file >> import_csv >> select >> select_into_file)
diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 400bbe895588a..bb8a1dc88080e 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -25,8 +25,10 @@
or the ``api/2.1/jobs/runs/submit``
`endpoint `_.
"""
+from __future__ import annotations
+
import json
-from typing import Any, Dict, List, Optional
+from typing import Any
from requests import exceptions as requests_exceptions
@@ -37,27 +39,29 @@
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
TERMINATE_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/delete")
-RUN_NOW_ENDPOINT = ('POST', 'api/2.1/jobs/run-now')
-SUBMIT_RUN_ENDPOINT = ('POST', 'api/2.1/jobs/runs/submit')
-GET_RUN_ENDPOINT = ('GET', 'api/2.1/jobs/runs/get')
-CANCEL_RUN_ENDPOINT = ('POST', 'api/2.1/jobs/runs/cancel')
-OUTPUT_RUNS_JOB_ENDPOINT = ('GET', 'api/2.1/jobs/runs/get-output')
+RUN_NOW_ENDPOINT = ("POST", "api/2.1/jobs/run-now")
+SUBMIT_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/submit")
+GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
+CANCEL_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel")
+OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output")
+
+INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
+UNINSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/uninstall")
-INSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/install')
-UNINSTALL_LIBS_ENDPOINT = ('POST', 'api/2.0/libraries/uninstall')
+LIST_JOBS_ENDPOINT = ("GET", "api/2.1/jobs/list")
-LIST_JOBS_ENDPOINT = ('GET', 'api/2.1/jobs/list')
+WORKSPACE_GET_STATUS_ENDPOINT = ("GET", "api/2.0/workspace/get-status")
-WORKSPACE_GET_STATUS_ENDPOINT = ('GET', 'api/2.0/workspace/get-status')
+RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"]
-RUN_LIFE_CYCLE_STATES = ['PENDING', 'RUNNING', 'TERMINATING', 'TERMINATED', 'SKIPPED', 'INTERNAL_ERROR']
+SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")
class RunState:
"""Utility class for the run state concept of Databricks runs."""
def __init__(
- self, life_cycle_state: str, result_state: str = '', state_message: str = '', *args, **kwargs
+ self, life_cycle_state: str, result_state: str = "", state_message: str = "", *args, **kwargs
) -> None:
self.life_cycle_state = life_cycle_state
self.result_state = result_state
@@ -69,17 +73,17 @@ def is_terminal(self) -> bool:
if self.life_cycle_state not in RUN_LIFE_CYCLE_STATES:
raise AirflowException(
(
- 'Unexpected life cycle state: {}: If the state has '
- 'been introduced recently, please check the Databricks user '
- 'guide for troubleshooting information'
+ "Unexpected life cycle state: {}: If the state has "
+ "been introduced recently, please check the Databricks user "
+ "guide for troubleshooting information"
).format(self.life_cycle_state)
)
- return self.life_cycle_state in ('TERMINATED', 'SKIPPED', 'INTERNAL_ERROR')
+ return self.life_cycle_state in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR")
@property
def is_successful(self) -> bool:
"""True if the result state is SUCCESS"""
- return self.result_state == 'SUCCESS'
+ return self.result_state == "SUCCESS"
def __eq__(self, other: object) -> bool:
if not isinstance(other, RunState):
@@ -97,7 +101,7 @@ def to_json(self) -> str:
return json.dumps(self.__dict__)
@classmethod
- def from_json(cls, data: str) -> 'RunState':
+ def from_json(cls, data: str) -> RunState:
return RunState(**json.loads(data))
@@ -115,7 +119,7 @@ class DatabricksHook(BaseDatabricksHook):
:param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""
- hook_name = 'Databricks'
+ hook_name = "Databricks"
def __init__(
self,
@@ -123,9 +127,10 @@ def __init__(
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
- retry_args: Optional[Dict[Any, Any]] = None,
+ retry_args: dict[Any, Any] | None = None,
+ caller: str = "DatabricksHook",
) -> None:
- super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args)
+ super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)
def run_now(self, json: dict) -> int:
"""
@@ -133,10 +138,9 @@ def run_now(self, json: dict) -> int:
:param json: The data used in the body of the request to the ``run-now`` endpoint.
:return: the run_id as an int
- :rtype: str
"""
response = self._do_api_call(RUN_NOW_ENDPOINT, json)
- return response['run_id']
+ return response["run_id"]
def submit_run(self, json: dict) -> int:
"""
@@ -144,46 +148,53 @@ def submit_run(self, json: dict) -> int:
:param json: The data used in the body of the request to the ``submit`` endpoint.
:return: the run_id as an int
- :rtype: str
"""
response = self._do_api_call(SUBMIT_RUN_ENDPOINT, json)
- return response['run_id']
+ return response["run_id"]
- def list_jobs(self, limit: int = 25, offset: int = 0, expand_tasks: bool = False) -> List[Dict[str, Any]]:
+ def list_jobs(
+ self, limit: int = 25, offset: int = 0, expand_tasks: bool = False, job_name: str | None = None
+ ) -> list[dict[str, Any]]:
"""
Lists the jobs in the Databricks Job Service.
:param limit: The limit/batch size used to retrieve jobs.
:param offset: The offset of the first job to return, relative to the most recently created job.
:param expand_tasks: Whether to include task and cluster details in the response.
+ :param job_name: Optional name of a job to search.
:return: A list of jobs.
"""
has_more = True
- jobs = []
+ all_jobs = []
while has_more:
- json = {
- 'limit': limit,
- 'offset': offset,
- 'expand_tasks': expand_tasks,
+ payload: dict[str, Any] = {
+ "limit": limit,
+ "expand_tasks": expand_tasks,
+ "offset": offset,
}
- response = self._do_api_call(LIST_JOBS_ENDPOINT, json)
- jobs += response['jobs'] if 'jobs' in response else []
- has_more = response.get('has_more', False)
+ if job_name:
+ payload["name"] = job_name
+ response = self._do_api_call(LIST_JOBS_ENDPOINT, payload)
+ jobs = response.get("jobs", [])
+ if job_name:
+ all_jobs += [j for j in jobs if j["settings"]["name"] == job_name]
+ else:
+ all_jobs += jobs
+ has_more = response.get("has_more", False)
if has_more:
- offset += len(response['jobs'])
+ offset += len(jobs)
- return jobs
+ return all_jobs
- def find_job_id_by_name(self, job_name: str) -> Optional[int]:
+ def find_job_id_by_name(self, job_name: str) -> int | None:
"""
Finds job id by its name. If there are multiple jobs with the same name, raises AirflowException.
:param job_name: The name of the job to look up.
:return: The job_id as an int or None if no job was found.
"""
- all_jobs = self.list_jobs()
- matching_jobs = [j for j in all_jobs if j['settings']['name'] == job_name]
+ matching_jobs = self.list_jobs(job_name=job_name)
if len(matching_jobs) > 1:
raise AirflowException(
@@ -193,7 +204,7 @@ def find_job_id_by_name(self, job_name: str) -> Optional[int]:
if not matching_jobs:
return None
else:
- return matching_jobs[0]['job_id']
+ return matching_jobs[0]["job_id"]
def get_run_page_url(self, run_id: int) -> str:
"""
@@ -202,9 +213,9 @@ def get_run_page_url(self, run_id: int) -> str:
:param run_id: id of the run
:return: URL of the run page
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
- return response['run_page_url']
+ return response["run_page_url"]
async def a_get_run_page_url(self, run_id: int) -> str:
"""
@@ -212,9 +223,9 @@ async def a_get_run_page_url(self, run_id: int) -> str:
:param run_id: id of the run
:return: URL of the run page
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
- return response['run_page_url']
+ return response["run_page_url"]
def get_job_id(self, run_id: int) -> int:
"""
@@ -223,9 +234,9 @@ def get_job_id(self, run_id: int) -> int:
:param run_id: id of the run
:return: Job id for given Databricks run
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
- return response['job_id']
+ return response["job_id"]
def get_run_state(self, run_id: int) -> RunState:
"""
@@ -242,9 +253,9 @@ def get_run_state(self, run_id: int) -> RunState:
:param run_id: id of the run
:return: state of the run
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
response = self._do_api_call(GET_RUN_ENDPOINT, json)
- state = response['state']
+ state = response["state"]
return RunState(**state)
async def a_get_run_state(self, run_id: int) -> RunState:
@@ -253,11 +264,33 @@ async def a_get_run_state(self, run_id: int) -> RunState:
:param run_id: id of the run
:return: state of the run
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
- state = response['state']
+ state = response["state"]
return RunState(**state)
+ def get_run(self, run_id: int) -> dict[str, Any]:
+ """
+ Retrieve run information.
+
+ :param run_id: id of the run
+ :return: run information as returned by the ``runs/get`` endpoint
+ """
+ json = {"run_id": run_id}
+ response = self._do_api_call(GET_RUN_ENDPOINT, json)
+ return response
+
+ async def a_get_run(self, run_id: int) -> dict[str, Any]:
+ """
+ Async version of `get_run`.
+
+ :param run_id: id of the run
+ :return: run information as returned by the ``runs/get`` endpoint
+ """
+ json = {"run_id": run_id}
+ response = await self._a_do_api_call(GET_RUN_ENDPOINT, json)
+ return response
+
def get_run_state_str(self, run_id: int) -> str:
"""
Return the string representation of RunState.
@@ -305,7 +338,7 @@ def get_run_output(self, run_id: int) -> dict:
:param run_id: id of the run
:return: output of the run
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json)
return run_output
@@ -315,7 +348,7 @@ def cancel_run(self, run_id: int) -> None:
:param run_id: id of the run
"""
- json = {'run_id': run_id}
+ json = {"run_id": run_id}
self._do_api_call(CANCEL_RUN_ENDPOINT, json)
def restart_cluster(self, json: dict) -> None:
@@ -362,7 +395,7 @@ def uninstall(self, json: dict) -> None:
"""
self._do_api_call(UNINSTALL_LIBS_ENDPOINT, json)
- def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict:
+ def update_repo(self, repo_id: str, json: dict[str, Any]) -> dict:
"""
Updates given Databricks Repos
@@ -370,7 +403,7 @@ def update_repo(self, repo_id: str, json: Dict[str, Any]) -> dict:
:param json: payload
:return: metadata from update
"""
- repos_endpoint = ('PATCH', f'api/2.0/repos/{repo_id}')
+ repos_endpoint = ("PATCH", f"api/2.0/repos/{repo_id}")
return self._do_api_call(repos_endpoint, json)
def delete_repo(self, repo_id: str):
@@ -380,31 +413,44 @@ def delete_repo(self, repo_id: str):
:param repo_id: ID of Databricks Repos
:return:
"""
- repos_endpoint = ('DELETE', f'api/2.0/repos/{repo_id}')
+ repos_endpoint = ("DELETE", f"api/2.0/repos/{repo_id}")
self._do_api_call(repos_endpoint)
- def create_repo(self, json: Dict[str, Any]) -> dict:
+ def create_repo(self, json: dict[str, Any]) -> dict:
"""
Creates a Databricks Repos
:param json: payload
:return:
"""
- repos_endpoint = ('POST', 'api/2.0/repos')
+ repos_endpoint = ("POST", "api/2.0/repos")
return self._do_api_call(repos_endpoint, json)
- def get_repo_by_path(self, path: str) -> Optional[str]:
+ def get_repo_by_path(self, path: str) -> str | None:
"""
Obtains Repos ID by path
:param path: path to a repository
:return: Repos ID if it exists, None if doesn't.
"""
try:
- result = self._do_api_call(WORKSPACE_GET_STATUS_ENDPOINT, {'path': path}, wrap_http_errors=False)
- if result.get('object_type', '') == 'REPO':
- return str(result['object_id'])
+ result = self._do_api_call(WORKSPACE_GET_STATUS_ENDPOINT, {"path": path}, wrap_http_errors=False)
+ if result.get("object_type", "") == "REPO":
+ return str(result["object_id"])
except requests_exceptions.HTTPError as e:
if e.response.status_code != 404:
raise e
return None
+
+ def test_connection(self) -> tuple[bool, str]:
+ """Test Databricks connectivity from the UI."""
+ hook = DatabricksHook(databricks_conn_id=self.databricks_conn_id)
+ try:
+ hook._do_api_call(endpoint_info=SPARK_VERSIONS_ENDPOINT).get("versions")
+ status = True
+ message = "Connection successfully tested"
+ except Exception as e:
+ status = False
+ message = str(e)
+
+ return status, message
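Review note on the ``list_jobs`` change in the file diff above: the new loop sends an optional ``name`` filter with each page request, still compares names exactly on the client side, and advances ``offset`` by the number of jobs actually returned. The sketch below reproduces that loop outside the hook so the paging behaviour can be checked in isolation. It is not the provider's code: ``api_call`` and ``make_fake_api`` are hypothetical stand-ins for ``_do_api_call`` and the Jobs list endpoint, and the fake endpoint deliberately ignores the ``name`` field so that only the client-side filter narrows the result.

```python
from typing import Any, Callable, Dict, List, Optional

Job = Dict[str, Any]


def list_jobs(
    api_call: Callable[[Dict[str, Any]], Dict[str, Any]],
    limit: int = 25,
    job_name: Optional[str] = None,
) -> List[Job]:
    """Collect every page of jobs, optionally keeping only exact name matches."""
    all_jobs: List[Job] = []
    offset = 0
    has_more = True
    while has_more:
        payload: Dict[str, Any] = {"limit": limit, "offset": offset}
        if job_name:
            # Ask the server to narrow the result; how it matches is up to the server.
            payload["name"] = job_name
        response = api_call(payload)
        jobs = response.get("jobs", [])
        if job_name:
            # Exact comparison on the client keeps the result independent of server matching.
            all_jobs += [j for j in jobs if j["settings"]["name"] == job_name]
        else:
            all_jobs += jobs
        has_more = response.get("has_more", False)
        if has_more:
            # Advance by what was returned, not by what was kept after filtering.
            offset += len(jobs)
    return all_jobs


def make_fake_api(jobs: List[Job]) -> Callable[[Dict[str, Any]], Dict[str, Any]]:
    """Hypothetical endpoint: serves slices of an in-memory list and ignores 'name'."""

    def fake_api(payload: Dict[str, Any]) -> Dict[str, Any]:
        start, size = payload["offset"], payload["limit"]
        return {"jobs": jobs[start : start + size], "has_more": start + size < len(jobs)}

    return fake_api
```

Because the offset advances by ``len(jobs)`` rather than by the number of matches, a page in which every job is filtered out does not stall the loop.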
diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 5b18dad9303ef..50ab2eff6a251 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -22,11 +22,13 @@
operators talk to the ``api/2.0/jobs/runs/submit``
`endpoint `_.
"""
+from __future__ import annotations
+
import copy
-import sys
+import platform
import time
-from typing import Any, Dict, Optional, Tuple
-from urllib.parse import urlparse
+from typing import Any
+from urllib.parse import urlsplit
import aiohttp
import requests
@@ -43,16 +45,11 @@
)
from airflow import __version__
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection
-
-if sys.version_info >= (3, 8):
- from functools import cached_property
-else:
- from cached_property import cached_property
-
-USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}
+from airflow.providers_manager import ProvidersManager
# https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
# https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
@@ -81,17 +78,17 @@ class BaseDatabricksHook(BaseHook):
:param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
"""
- conn_name_attr = 'databricks_conn_id'
- default_conn_name = 'databricks_default'
- conn_type = 'databricks'
+ conn_name_attr: str = "databricks_conn_id"
+ default_conn_name = "databricks_default"
+ conn_type = "databricks"
extra_parameters = [
- 'token',
- 'host',
- 'use_azure_managed_identity',
- 'azure_ad_endpoint',
- 'azure_resource_id',
- 'azure_tenant_id',
+ "token",
+ "host",
+ "use_azure_managed_identity",
+ "azure_ad_endpoint",
+ "azure_resource_id",
+ "azure_tenant_id",
]
def __init__(
@@ -100,25 +97,27 @@ def __init__(
timeout_seconds: int = 180,
retry_limit: int = 3,
retry_delay: float = 1.0,
- retry_args: Optional[Dict[Any, Any]] = None,
+ retry_args: dict[Any, Any] | None = None,
+ caller: str = "Unknown",
) -> None:
super().__init__()
self.databricks_conn_id = databricks_conn_id
self.timeout_seconds = timeout_seconds
if retry_limit < 1:
- raise ValueError('Retry limit must be greater than or equal to 1')
+ raise ValueError("Retry limit must be greater than or equal to 1")
self.retry_limit = retry_limit
self.retry_delay = retry_delay
- self.aad_tokens: Dict[str, dict] = {}
+ self.aad_tokens: dict[str, dict] = {}
self.aad_timeout_seconds = 10
+ self.caller = caller
def my_after_func(retry_state):
self._log_request_error(retry_state.attempt_number, retry_state.outcome)
if retry_args:
self.retry_args = copy.copy(retry_args)
- self.retry_args['retry'] = retry_if_exception(self._retryable_error)
- self.retry_args['after'] = my_after_func
+ self.retry_args["retry"] = retry_if_exception(self._retryable_error)
+ self.retry_args["after"] = my_after_func
else:
self.retry_args = dict(
stop=stop_after_attempt(self.retry_limit),
@@ -134,10 +133,28 @@ def databricks_conn(self) -> Connection:
def get_conn(self) -> Connection:
return self.databricks_conn
+ @cached_property
+ def user_agent_header(self) -> dict[str, str]:
+ return {"user-agent": self.user_agent_value}
+
+ @cached_property
+ def user_agent_value(self) -> str:
+ manager = ProvidersManager()
+ package_name = manager.hooks[BaseDatabricksHook.conn_type].package_name # type: ignore[union-attr]
+ provider = manager.providers[package_name]
+ version = provider.version
+ python_version = platform.python_version()
+ system = platform.system().lower()
+ ua_string = (
+ f"databricks-airflow/{version} _/0.0.0 python/{python_version} os/{system} "
+ f"airflow/{__version__} operator/{self.caller}"
+ )
+ return ua_string
+
@cached_property
def host(self) -> str:
- if 'host' in self.databricks_conn.extra_dejson:
- host = self._parse_host(self.databricks_conn.extra_dejson['host'])
+ if "host" in self.databricks_conn.extra_dejson:
+ host = self._parse_host(self.databricks_conn.extra_dejson["host"])
else:
host = self._parse_host(self.databricks_conn.host)
@@ -154,8 +171,7 @@ async def __aexit__(self, *err):
@staticmethod
def _parse_host(host: str) -> str:
"""
- The purpose of this function is to be robust to improper connections
- settings provided by users, specifically in the host field.
+ This function tolerates malformed connection settings supplied by users in the host field.
For example -- when users supply ``https://xx.cloud.databricks.com`` as the
host, we must strip out the protocol to get the host.::
@@ -170,7 +186,7 @@ def _parse_host(host: str) -> str:
assert h._parse_host('xx.cloud.databricks.com') == 'xx.cloud.databricks.com'
"""
- urlparse_host = urlparse(host).hostname
+ urlparse_host = urlsplit(host).hostname
if urlparse_host:
# In this case, host = https://xx.cloud.databricks.com
return urlparse_host
@@ -180,33 +196,35 @@ def _parse_host(host: str) -> str:
def _get_retry_object(self) -> Retrying:
"""
- Instantiates a retry object
+ Instantiate a retry object.
:return: instance of Retrying class
"""
return Retrying(**self.retry_args)
def _a_get_retry_object(self) -> AsyncRetrying:
"""
- Instantiates an async retry object
+ Instantiate an async retry object.
:return: instance of AsyncRetrying class
"""
return AsyncRetrying(**self.retry_args)
def _get_aad_token(self, resource: str) -> str:
"""
- Function to get AAD token for given resource. Supports managed identity or service principal auth
+ Function to get AAD token for given resource.
+
+ Supports managed identity or service principal auth.
:param resource: resource to issue token to
:return: AAD token, or raise an exception
"""
aad_token = self.aad_tokens.get(resource)
if aad_token and self._is_aad_token_valid(aad_token):
- return aad_token['token']
+ return aad_token["token"]
- self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
+ self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
try:
for attempt in self._get_retry_object():
with attempt:
- if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+ if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
params = {
"api-version": "2018-02-01",
"resource": resource,
@@ -214,11 +232,11 @@ def _get_aad_token(self, resource: str) -> str:
resp = requests.get(
AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
- headers={**USER_AGENT_HEADER, "Metadata": "true"},
+ headers={**self.user_agent_header, "Metadata": "true"},
timeout=self.aad_timeout_seconds,
)
else:
- tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
+ tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
@@ -232,8 +250,8 @@ def _get_aad_token(self, resource: str) -> str:
AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
- **USER_AGENT_HEADER,
- 'Content-Type': 'application/x-www-form-urlencoded',
+ **self.user_agent_header,
+ "Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.aad_timeout_seconds,
)
@@ -241,19 +259,19 @@ def _get_aad_token(self, resource: str) -> str:
resp.raise_for_status()
jsn = resp.json()
if (
- 'access_token' not in jsn
- or jsn.get('token_type') != 'Bearer'
- or 'expires_on' not in jsn
+ "access_token" not in jsn
+ or jsn.get("token_type") != "Bearer"
+ or "expires_on" not in jsn
):
raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
- token = jsn['access_token']
- self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
+ token = jsn["access_token"]
+ self.aad_tokens[resource] = {"token": token, "expires_on": int(jsn["expires_on"])}
break
except RetryError:
- raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
+ raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
except requests_exceptions.HTTPError as e:
- raise AirflowException(f'Response: {e.response.content}, Status Code: {e.response.status_code}')
+ raise AirflowException(f"Response: {e.response.content}, Status Code: {e.response.status_code}")
return token
@@ -265,13 +283,13 @@ async def _a_get_aad_token(self, resource: str) -> str:
"""
aad_token = self.aad_tokens.get(resource)
if aad_token and self._is_aad_token_valid(aad_token):
- return aad_token['token']
+ return aad_token["token"]
- self.log.info('Existing AAD token is expired, or going to expire soon. Refreshing...')
+ self.log.info("Existing AAD token is expired, or going to expire soon. Refreshing...")
try:
async for attempt in self._a_get_retry_object():
with attempt:
- if self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
+ if self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
params = {
"api-version": "2018-02-01",
"resource": resource,
@@ -279,13 +297,13 @@ async def _a_get_aad_token(self, resource: str) -> str:
async with self._session.get(
url=AZURE_METADATA_SERVICE_TOKEN_URL,
params=params,
- headers={**USER_AGENT_HEADER, "Metadata": "true"},
+ headers={**self.user_agent_header, "Metadata": "true"},
timeout=self.aad_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
else:
- tenant_id = self.databricks_conn.extra_dejson['azure_tenant_id']
+ tenant_id = self.databricks_conn.extra_dejson["azure_tenant_id"]
data = {
"grant_type": "client_credentials",
"client_id": self.databricks_conn.login,
@@ -299,42 +317,42 @@ async def _a_get_aad_token(self, resource: str) -> str:
url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
data=data,
headers={
- **USER_AGENT_HEADER,
- 'Content-Type': 'application/x-www-form-urlencoded',
+ **self.user_agent_header,
+ "Content-Type": "application/x-www-form-urlencoded",
},
timeout=self.aad_timeout_seconds,
) as resp:
resp.raise_for_status()
jsn = await resp.json()
if (
- 'access_token' not in jsn
- or jsn.get('token_type') != 'Bearer'
- or 'expires_on' not in jsn
+ "access_token" not in jsn
+ or jsn.get("token_type") != "Bearer"
+ or "expires_on" not in jsn
):
raise AirflowException(f"Can't get necessary data from AAD token: {jsn}")
- token = jsn['access_token']
- self.aad_tokens[resource] = {'token': token, 'expires_on': int(jsn["expires_on"])}
+ token = jsn["access_token"]
+ self.aad_tokens[resource] = {"token": token, "expires_on": int(jsn["expires_on"])}
break
except RetryError:
- raise AirflowException(f'API requests to Azure failed {self.retry_limit} times. Giving up.')
+ raise AirflowException(f"API requests to Azure failed {self.retry_limit} times. Giving up.")
except aiohttp.ClientResponseError as err:
- raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')
+ raise AirflowException(f"Response: {err.message}, Status Code: {err.status}")
return token
def _get_aad_headers(self) -> dict:
"""
- Fills AAD headers if necessary (SPN is outside of the workspace)
+ Fill AAD headers if necessary (SPN is outside of the workspace).
:return: dictionary with filled AAD headers
"""
headers = {}
- if 'azure_resource_id' in self.databricks_conn.extra_dejson:
+ if "azure_resource_id" in self.databricks_conn.extra_dejson:
mgmt_token = self._get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
- headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
- 'azure_resource_id'
+ headers["X-Databricks-Azure-Workspace-Resource-Id"] = self.databricks_conn.extra_dejson[
+ "azure_resource_id"
]
- headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
+ headers["X-Databricks-Azure-SP-Management-Token"] = mgmt_token
return headers
async def _a_get_aad_headers(self) -> dict:
@@ -343,31 +361,31 @@ async def _a_get_aad_headers(self) -> dict:
:return: dictionary with filled AAD headers
"""
headers = {}
- if 'azure_resource_id' in self.databricks_conn.extra_dejson:
+ if "azure_resource_id" in self.databricks_conn.extra_dejson:
mgmt_token = await self._a_get_aad_token(AZURE_MANAGEMENT_ENDPOINT)
- headers['X-Databricks-Azure-Workspace-Resource-Id'] = self.databricks_conn.extra_dejson[
- 'azure_resource_id'
+ headers["X-Databricks-Azure-Workspace-Resource-Id"] = self.databricks_conn.extra_dejson[
+ "azure_resource_id"
]
- headers['X-Databricks-Azure-SP-Management-Token'] = mgmt_token
+ headers["X-Databricks-Azure-SP-Management-Token"] = mgmt_token
return headers
@staticmethod
def _is_aad_token_valid(aad_token: dict) -> bool:
"""
- Utility function to check AAD token hasn't expired yet
+ Utility function to check AAD token hasn't expired yet.
+
:param aad_token: dict with properties of AAD token
:return: true if token is valid, false otherwise
- :rtype: bool
"""
now = int(time.time())
- if aad_token['expires_on'] > (now + TOKEN_REFRESH_LEAD_TIME):
+ if aad_token["expires_on"] > (now + TOKEN_REFRESH_LEAD_TIME):
return True
return False
@staticmethod
def _check_azure_metadata_service() -> None:
"""
- Check for Azure Metadata Service
+ Check for Azure Metadata Service.
https://docs.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service
"""
try:
@@ -377,7 +395,7 @@ def _check_azure_metadata_service() -> None:
headers={"Metadata": "true"},
timeout=2,
).json()
- if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
+ if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
@@ -394,113 +412,112 @@ async def _a_check_azure_metadata_service(self):
timeout=2,
) as resp:
jsn = await resp.json()
- if 'compute' not in jsn or 'azEnvironment' not in jsn['compute']:
+ if "compute" not in jsn or "azEnvironment" not in jsn["compute"]:
raise AirflowException(
f"Was able to fetch some metadata, but it doesn't look like Azure Metadata: {jsn}"
)
except (requests_exceptions.RequestException, ValueError) as e:
raise AirflowException(f"Can't reach Azure Metadata Service: {e}")
- def _get_token(self, raise_error: bool = False) -> Optional[str]:
- if 'token' in self.databricks_conn.extra_dejson:
+ def _get_token(self, raise_error: bool = False) -> str | None:
+ if "token" in self.databricks_conn.extra_dejson:
self.log.info(
- 'Using token auth. For security reasons, please set token in Password field instead of extra'
+ "Using token auth. For security reasons, please set token in Password field instead of extra"
)
- return self.databricks_conn.extra_dejson['token']
+ return self.databricks_conn.extra_dejson["token"]
elif not self.databricks_conn.login and self.databricks_conn.password:
- self.log.info('Using token auth.')
+ self.log.info("Using token auth.")
return self.databricks_conn.password
- elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
+ elif "azure_tenant_id" in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")
- self.log.info('Using AAD Token for SPN.')
+ self.log.info("Using AAD Token for SPN.")
return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
- elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
- self.log.info('Using AAD Token for managed identity.')
+ elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
+ self.log.info("Using AAD Token for managed identity.")
self._check_azure_metadata_service()
return self._get_aad_token(DEFAULT_DATABRICKS_SCOPE)
elif raise_error:
- raise AirflowException('Token authentication isn\'t configured')
+ raise AirflowException("Token authentication isn't configured")
return None
- async def _a_get_token(self, raise_error: bool = False) -> Optional[str]:
- if 'token' in self.databricks_conn.extra_dejson:
+ async def _a_get_token(self, raise_error: bool = False) -> str | None:
+ if "token" in self.databricks_conn.extra_dejson:
self.log.info(
- 'Using token auth. For security reasons, please set token in Password field instead of extra'
+ "Using token auth. For security reasons, please set token in Password field instead of extra"
)
return self.databricks_conn.extra_dejson["token"]
elif not self.databricks_conn.login and self.databricks_conn.password:
- self.log.info('Using token auth.')
+ self.log.info("Using token auth.")
return self.databricks_conn.password
- elif 'azure_tenant_id' in self.databricks_conn.extra_dejson:
+ elif "azure_tenant_id" in self.databricks_conn.extra_dejson:
if self.databricks_conn.login == "" or self.databricks_conn.password == "":
raise AirflowException("Azure SPN credentials aren't provided")
- self.log.info('Using AAD Token for SPN.')
+ self.log.info("Using AAD Token for SPN.")
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
- elif self.databricks_conn.extra_dejson.get('use_azure_managed_identity', False):
- self.log.info('Using AAD Token for managed identity.')
+ elif self.databricks_conn.extra_dejson.get("use_azure_managed_identity", False):
+ self.log.info("Using AAD Token for managed identity.")
await self._a_check_azure_metadata_service()
return await self._a_get_aad_token(DEFAULT_DATABRICKS_SCOPE)
elif raise_error:
- raise AirflowException('Token authentication isn\'t configured')
+ raise AirflowException("Token authentication isn't configured")
return None
def _log_request_error(self, attempt_num: int, error: str) -> None:
- self.log.error('Attempt %s API Request to Databricks failed with reason: %s', attempt_num, error)
+ self.log.error("Attempt %s API Request to Databricks failed with reason: %s", attempt_num, error)
def _do_api_call(
self,
- endpoint_info: Tuple[str, str],
- json: Optional[Dict[str, Any]] = None,
+ endpoint_info: tuple[str, str],
+ json: dict[str, Any] | None = None,
wrap_http_errors: bool = True,
):
"""
- Utility function to perform an API call with retries
+ Utility function to perform an API call with retries.
:param endpoint_info: Tuple of method and endpoint
:param json: Parameters for this API call.
:return: If the api call returns a OK status code,
this function returns the response in JSON. Otherwise,
we throw an AirflowException.
- :rtype: dict
"""
method, endpoint = endpoint_info
# TODO: get rid of explicit 'api/' in the endpoint specification
- url = f'https://{self.host}/{endpoint}'
+ url = f"https://{self.host}/{endpoint}"
aad_headers = self._get_aad_headers()
- headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
+ headers = {**self.user_agent_header, **aad_headers}
auth: AuthBase
token = self._get_token()
if token:
auth = _TokenAuth(token)
else:
- self.log.info('Using basic auth.')
+ self.log.info("Using basic auth.")
auth = HTTPBasicAuth(self.databricks_conn.login, self.databricks_conn.password)
request_func: Any
- if method == 'GET':
+ if method == "GET":
request_func = requests.get
- elif method == 'POST':
+ elif method == "POST":
request_func = requests.post
- elif method == 'PATCH':
+ elif method == "PATCH":
request_func = requests.patch
- elif method == 'DELETE':
+ elif method == "DELETE":
request_func = requests.delete
else:
- raise AirflowException('Unexpected HTTP Method: ' + method)
+ raise AirflowException("Unexpected HTTP Method: " + method)
try:
for attempt in self._get_retry_object():
with attempt:
response = request_func(
url,
- json=json if method in ('POST', 'PATCH') else None,
- params=json if method == 'GET' else None,
+ json=json if method in ("POST", "PATCH") else None,
+ params=json if method == "GET" else None,
auth=auth,
headers=headers,
timeout=self.timeout_seconds,
@@ -508,16 +525,16 @@ def _do_api_call(
response.raise_for_status()
return response.json()
except RetryError:
- raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
+ raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.")
except requests_exceptions.HTTPError as e:
if wrap_http_errors:
raise AirflowException(
- f'Response: {e.response.content}, Status Code: {e.response.status_code}'
+ f"Response: {e.response.content}, Status Code: {e.response.status_code}"
)
else:
raise e
- async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Dict[str, Any]] = None):
+ async def _a_do_api_call(self, endpoint_info: tuple[str, str], json: dict[str, Any] | None = None):
"""
Async version of `_do_api_call()`.
:param endpoint_info: Tuple of method and endpoint
@@ -527,28 +544,28 @@ async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Di
"""
method, endpoint = endpoint_info
- url = f'https://{self.host}/{endpoint}'
+ url = f"https://{self.host}/{endpoint}"
aad_headers = await self._a_get_aad_headers()
- headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
+ headers = {**self.user_agent_header, **aad_headers}
auth: aiohttp.BasicAuth
token = await self._a_get_token()
if token:
auth = BearerAuth(token)
else:
- self.log.info('Using basic auth.')
+ self.log.info("Using basic auth.")
auth = aiohttp.BasicAuth(self.databricks_conn.login, self.databricks_conn.password)
request_func: Any
- if method == 'GET':
+ if method == "GET":
request_func = self._session.get
- elif method == 'POST':
+ elif method == "POST":
request_func = self._session.post
- elif method == 'PATCH':
+ elif method == "PATCH":
request_func = self._session.patch
else:
- raise AirflowException('Unexpected HTTP Method: ' + method)
+ raise AirflowException("Unexpected HTTP Method: " + method)
try:
async for attempt in self._a_get_retry_object():
with attempt:
@@ -556,22 +573,22 @@ async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Di
url,
json=json,
auth=auth,
- headers={**headers, **USER_AGENT_HEADER},
+ headers={**headers, **self.user_agent_header},
timeout=self.timeout_seconds,
) as response:
response.raise_for_status()
return await response.json()
except RetryError:
- raise AirflowException(f'API requests to Databricks failed {self.retry_limit} times. Giving up.')
+ raise AirflowException(f"API requests to Databricks failed {self.retry_limit} times. Giving up.")
except aiohttp.ClientResponseError as err:
- raise AirflowException(f'Response: {err.message}, Status Code: {err.status}')
+ raise AirflowException(f"Response: {err.message}, Status Code: {err.status}")
@staticmethod
def _get_error_code(exception: BaseException) -> str:
if isinstance(exception, requests_exceptions.HTTPError):
try:
jsn = exception.response.json()
- return jsn.get('error_code', '')
+ return jsn.get("error_code", "")
except JSONDecodeError:
pass
@@ -587,7 +604,7 @@ def _retryable_error(exception: BaseException) -> bool:
or exception.response.status_code == 429
or (
exception.response.status_code == 400
- and BaseDatabricksHook._get_error_code(exception) == 'COULD_NOT_ACQUIRE_LOCK'
+ and BaseDatabricksHook._get_error_code(exception) == "COULD_NOT_ACQUIRE_LOCK"
)
)
):
@@ -602,7 +619,9 @@ def _retryable_error(exception: BaseException) -> bool:
class _TokenAuth(AuthBase):
"""
- Helper class for requests Auth field. AuthBase requires you to implement the __call__
+ Helper class for requests Auth field.
+
+ AuthBase requires you to implement the ``__call__``
magic function.
"""
@@ -610,18 +629,18 @@ def __init__(self, token: str) -> None:
self.token = token
def __call__(self, r: PreparedRequest) -> PreparedRequest:
- r.headers['Authorization'] = 'Bearer ' + self.token
+ r.headers["Authorization"] = "Bearer " + self.token
return r
class BearerAuth(aiohttp.BasicAuth):
"""aiohttp only ships BasicAuth; for Bearer auth we need a subclass of BasicAuth."""
- def __new__(cls, token: str) -> 'BearerAuth':
+ def __new__(cls, token: str) -> BearerAuth:
return super().__new__(cls, token) # type: ignore
def __init__(self, token: str) -> None:
self.token = token
def encode(self) -> str:
- return f'Bearer {self.token}'
+ return f"Bearer {self.token}"
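The request-building rule in ``_do_api_call`` above (payload sent as a JSON body for POST and PATCH, as query parameters for GET, and not at all for DELETE) can be sketched as a small pure function. This is only an illustrative sketch: ``build_request_kwargs`` and ``SUPPORTED_METHODS`` are hypothetical names that do not exist in the provider, and the sketch raises ``ValueError`` where the hook raises ``AirflowException``.

```python
from typing import Any, Dict, Optional

# Methods accepted by the synchronous _do_api_call shown in the diff.
SUPPORTED_METHODS = ("GET", "POST", "PATCH", "DELETE")


def build_request_kwargs(method: str, payload: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper mirroring how the hook routes the payload.

    POST/PATCH send the payload as the JSON body, GET sends it as query
    parameters, and DELETE sends neither.
    """
    if method not in SUPPORTED_METHODS:
        # The hook raises AirflowException here; ValueError keeps the sketch dependency-free.
        raise ValueError("Unexpected HTTP Method: " + method)
    return {
        "json": payload if method in ("POST", "PATCH") else None,
        "params": payload if method == "GET" else None,
    }
```

Note that the async variant in the same diff passes ``json=json`` unconditionally and does not handle DELETE, so this sketch reflects only the synchronous path.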
diff --git a/airflow/providers/databricks/hooks/databricks_sql.py b/airflow/providers/databricks/hooks/databricks_sql.py
index 9d86b4dbe3ae6..f042435943889 100644
--- a/airflow/providers/databricks/hooks/databricks_sql.py
+++ b/airflow/providers/databricks/hooks/databricks_sql.py
@@ -14,22 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from __future__ import annotations
-import re
from contextlib import closing
from copy import copy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Iterable, Mapping
from databricks import sql # type: ignore[attr-defined]
from databricks.sql.client import Connection # type: ignore[attr-defined]
-from airflow import __version__
from airflow.exceptions import AirflowException
-from airflow.hooks.dbapi import DbApiHook
+from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook
-LIST_SQL_ENDPOINTS_ENDPOINT = ('GET', 'api/2.0/sql/endpoints')
-USER_AGENT_STRING = f'airflow-{__version__}'
+LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")
class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
@@ -52,22 +50,24 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
:param kwargs: Additional parameters internal to Databricks SQL Connector parameters
"""
- hook_name = 'Databricks SQL'
+ hook_name = "Databricks SQL"
+ _test_connection_sql = "select 42"
def __init__(
self,
databricks_conn_id: str = BaseDatabricksHook.default_conn_name,
- http_path: Optional[str] = None,
- sql_endpoint_name: Optional[str] = None,
- session_configuration: Optional[Dict[str, str]] = None,
- http_headers: Optional[List[Tuple[str, str]]] = None,
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
+ http_path: str | None = None,
+ sql_endpoint_name: str | None = None,
+ session_configuration: dict[str, str] | None = None,
+ http_headers: list[tuple[str, str]] | None = None,
+ catalog: str | None = None,
+ schema: str | None = None,
+ caller: str = "DatabricksSqlHook",
**kwargs,
) -> None:
- super().__init__(databricks_conn_id)
+ super().__init__(databricks_conn_id, caller=caller)
self._sql_conn = None
- self._token: Optional[str] = None
+ self._token: str | None = None
self._http_path = http_path
self._sql_endpoint_name = sql_endpoint_name
self.supports_autocommit = True
@@ -77,19 +77,19 @@ def __init__(
self.schema = schema
self.additional_params = kwargs
- def _get_extra_config(self) -> Dict[str, Optional[Any]]:
+ def _get_extra_config(self) -> dict[str, Any | None]:
extra_params = copy(self.databricks_conn.extra_dejson)
- for arg in ['http_path', 'session_configuration'] + self.extra_parameters:
+ for arg in ["http_path", "session_configuration"] + self.extra_parameters:
if arg in extra_params:
del extra_params[arg]
return extra_params
- def _get_sql_endpoint_by_name(self, endpoint_name) -> Dict[str, Any]:
+ def _get_sql_endpoint_by_name(self, endpoint_name) -> dict[str, Any]:
result = self._do_api_call(LIST_SQL_ENDPOINTS_ENDPOINT)
- if 'endpoints' not in result:
+ if "endpoints" not in result:
raise AirflowException("Can't list Databricks SQL endpoints")
- lst = [endpoint for endpoint in result['endpoints'] if endpoint['name'] == endpoint_name]
+ lst = [endpoint for endpoint in result["endpoints"] if endpoint["name"] == endpoint_name]
if len(lst) == 0:
raise AirflowException(f"Can't find Databricks SQL endpoint with name '{endpoint_name}'")
return lst[0]
@@ -99,9 +99,9 @@ def get_conn(self) -> Connection:
if not self._http_path:
if self._sql_endpoint_name:
endpoint = self._get_sql_endpoint_by_name(self._sql_endpoint_name)
- self._http_path = endpoint['odbc_params']['path']
- elif 'http_path' in self.databricks_conn.extra_dejson:
- self._http_path = self.databricks_conn.extra_dejson['http_path']
+ self._http_path = endpoint["odbc_params"]["path"]
+ elif "http_path" in self.databricks_conn.extra_dejson:
+ self._http_path = self.databricks_conn.extra_dejson["http_path"]
else:
raise AirflowException(
"http_path should be provided either explicitly, "
@@ -120,7 +120,7 @@ def get_conn(self) -> Connection:
requires_init = False
if not self.session_config:
- self.session_config = self.databricks_conn.extra_dejson.get('session_configuration')
+ self.session_config = self.databricks_conn.extra_dejson.get("session_configuration")
if not self._sql_conn or requires_init:
if self._sql_conn: # close already existing connection
@@ -133,25 +133,21 @@ def get_conn(self) -> Connection:
catalog=self.catalog,
session_configuration=self.session_config,
http_headers=self.http_headers,
- _user_agent_entry=USER_AGENT_STRING,
+ _user_agent_entry=self.user_agent_value,
**self._get_extra_config(),
**self.additional_params,
)
return self._sql_conn
- @staticmethod
- def maybe_split_sql_string(sql: str) -> List[str]:
- """
- Splits strings consisting of multiple SQL expressions into an
- TODO: do we need something more sophisticated?
-
- :param sql: SQL string potentially consisting of multiple expressions
- :return: list of individual expressions
- """
- splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""]
- return splits
-
- def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
+ def run(
+ self,
+ sql: str | Iterable[str],
+ autocommit: bool = False,
+ parameters: Iterable | Mapping | None = None,
+ handler: Callable | None = None,
+ split_statements: bool = True,
+ return_last: bool = True,
+ ) -> Any | list[Any] | None:
"""
Runs a command or a list of commands. Pass a list of sql
statements to the sql parameter to get them to execute
@@ -163,49 +159,44 @@ def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, hand
before executing the query.
:param parameters: The parameters to render the SQL query with.
:param handler: The result handler which is called with the result of each statement.
- :return: query results.
+ :param split_statements: Whether to split a single SQL string into statements and run separately
+ :param return_last: Whether to return result for only last statement or for all after split
+ :return: return only result of the LAST SQL expression if handler was provided.
"""
+ self.scalar_return_last = isinstance(sql, str) and return_last
if isinstance(sql, str):
- sql = self.maybe_split_sql_string(sql)
+ if split_statements:
+ sql = self.split_sql_string(sql)
+ else:
+ sql = [self.strip_sql_string(sql)]
if sql:
- self.log.debug("Executing %d statements", len(sql))
+ self.log.debug("Executing following statements against Databricks DB: %s", list(sql))
else:
raise ValueError("List of SQL statements is empty")
- conn = None
+ results = []
for sql_statement in sql:
# when using AAD tokens, the token could expire if the previous query runs longer than the token lifetime
- conn = self.get_conn()
- with closing(conn.cursor()) as cur:
- self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters)
- if parameters:
- cur.execute(sql_statement, parameters)
- else:
- cur.execute(sql_statement)
- schema = cur.description
- results = []
- if handler is not None:
- cur = handler(cur)
- for row in cur:
- self.log.debug("Statement results: %s", row)
- results.append(row)
-
- self.log.info("Rows affected: %s", cur.rowcount)
- if conn:
- conn.close()
- self._sql_conn = None
+ with closing(self.get_conn()) as conn:
+ self.set_autocommit(conn, autocommit)
+
+ with closing(conn.cursor()) as cur:
+ self._run_command(cur, sql_statement, parameters)
- # Return only result of the last SQL expression
- return schema, results
+ if handler is not None:
+ result = handler(cur)
+ results.append(result)
+ self.last_description = cur.description
- def test_connection(self):
- """Test the Databricks SQL connection by running a simple query."""
- try:
- self.run(sql="select 42")
- except Exception as e:
- return False, str(e)
- return True, "Connection successfully checked"
+ self._sql_conn = None
+
+ if handler is None:
+ return None
+ elif self.scalar_return_last:
+ return results[-1]
+ else:
+ return results
def bulk_dump(self, table, tmp_file):
raise NotImplementedError()
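The return-value rules of the rewritten ``run`` method above may be easier to follow in isolation. The following is a minimal sketch, not the hook's implementation: ``sketch_run`` is a hypothetical name, the handler receives the statement text instead of a database cursor, and statements are split naively on ``;`` as a stand-in for the hook's ``split_sql_string`` and ``strip_sql_string`` helpers.

```python
from typing import Any, Callable, Iterable, List, Optional, Union


def sketch_run(
    sql: Union[str, Iterable[str]],
    handler: Optional[Callable[[str], Any]] = None,
    split_statements: bool = True,
    return_last: bool = True,
) -> Any:
    """Hypothetical stand-in for DatabricksSqlHook.run, without any database access."""
    # A single string with return_last=True yields one scalar result, not a list.
    scalar_return_last = isinstance(sql, str) and return_last
    if isinstance(sql, str):
        if split_statements:
            # Assumption: a naive split on ";" replaces the hook's real SQL splitter.
            statements = [s.strip() for s in sql.split(";") if s.strip()]
        else:
            statements = [sql.strip()]
    else:
        statements = list(sql)
    if not statements:
        raise ValueError("List of SQL statements is empty")

    results: List[Any] = []
    for statement in statements:
        # The real hook executes the statement on a cursor and passes the cursor to the handler.
        if handler is not None:
            results.append(handler(statement))

    if handler is None:
        return None
    if scalar_return_last:
        return results[-1]
    return results
```

Under these assumptions, a string input returns only the last handler result by default, a list input always returns a list, and omitting the handler returns ``None``. This differs from the removed implementation, which returned a ``(schema, results)`` tuple.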
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 8af4474b1315f..dbb44c70829a5 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -15,26 +15,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""This module contains Databricks operators."""
+from __future__ import annotations
import time
from logging import Logger
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Sequence
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunState
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
-from airflow.providers.databricks.utils.databricks import deep_string_coerce, validate_trigger_event
+from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.context import Context
-DEFER_METHOD_NAME = 'execute_complete'
-XCOM_RUN_ID_KEY = 'run_id'
-XCOM_RUN_PAGE_URL_KEY = 'run_page_url'
+DEFER_METHOD_NAME = "execute_complete"
+XCOM_RUN_ID_KEY = "run_id"
+XCOM_RUN_PAGE_URL_KEY = "run_page_url"
def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
@@ -45,35 +46,54 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
:param context: Airflow context
"""
if operator.do_xcom_push and context is not None:
- context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
- log.info('Run submitted with run_id: %s', operator.run_id)
+ context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
+ log.info("Run submitted with run_id: %s", operator.run_id)
run_page_url = hook.get_run_page_url(operator.run_id)
if operator.do_xcom_push and context is not None:
- context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
+ context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
if operator.wait_for_termination:
while True:
- run_state = hook.get_run_state(operator.run_id)
+ run_info = hook.get_run(operator.run_id)
+ run_state = RunState(**run_info["state"])
if run_state.is_terminal:
if run_state.is_successful:
- log.info('%s completed successfully.', operator.task_id)
- log.info('View run status, Spark UI, and logs at %s', run_page_url)
+ log.info("%s completed successfully.", operator.task_id)
+ log.info("View run status, Spark UI, and logs at %s", run_page_url)
return
else:
- run_output = hook.get_run_output(operator.run_id)
- notebook_error = run_output['error']
- error_message = (
- f'{operator.task_id} failed with terminal state: {run_state} '
- f'and with the error {notebook_error}'
- )
+ if run_state.result_state == "FAILED":
+ task_run_id = None
+ if "tasks" in run_info:
+ for task in run_info["tasks"]:
+ if task.get("state", {}).get("result_state", "") == "FAILED":
+ task_run_id = task["run_id"]
+ if task_run_id is not None:
+ run_output = hook.get_run_output(task_run_id)
+ if "error" in run_output:
+ notebook_error = run_output["error"]
+ else:
+ notebook_error = run_state.state_message
+ else:
+ notebook_error = run_state.state_message
+ error_message = (
+ f"{operator.task_id} failed with terminal state: {run_state} "
+ f"and with the error {notebook_error}"
+ )
+ else:
+ error_message = (
+ f"{operator.task_id} failed with terminal state: {run_state} "
+ f"and with the error {run_state.state_message}"
+ )
raise AirflowException(error_message)
+
else:
- log.info('%s in run state: %s', operator.task_id, run_state)
- log.info('View run status, Spark UI, and logs at %s', run_page_url)
- log.info('Sleeping for %s seconds.', operator.polling_period_seconds)
+ log.info("%s in run state: %s", operator.task_id, run_state)
+ log.info("View run status, Spark UI, and logs at %s", run_page_url)
+ log.info("Sleeping for %s seconds.", operator.polling_period_seconds)
time.sleep(operator.polling_period_seconds)
else:
- log.info('View run status, Spark UI, and logs at %s', run_page_url)
+ log.info("View run status, Spark UI, and logs at %s", run_page_url)
def _handle_deferrable_databricks_operator_execution(operator, hook, log, context) -> None:
@@ -84,13 +104,13 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex
:param context: Airflow context
"""
if operator.do_xcom_push and context is not None:
- context['ti'].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
- log.info(f'Run submitted with run_id: {operator.run_id}')
+ context["ti"].xcom_push(key=XCOM_RUN_ID_KEY, value=operator.run_id)
+ log.info("Run submitted with run_id: %s", operator.run_id)
run_page_url = hook.get_run_page_url(operator.run_id)
if operator.do_xcom_push and context is not None:
- context['ti'].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
- log.info(f'View run status, Spark UI, and logs at {run_page_url}')
+ context["ti"].xcom_push(key=XCOM_RUN_PAGE_URL_KEY, value=run_page_url)
+ log.info("View run status, Spark UI, and logs at %s", run_page_url)
if operator.wait_for_termination:
operator.defer(
@@ -105,15 +125,15 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex
def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None:
validate_trigger_event(event)
- run_state = RunState.from_json(event['run_state'])
- run_page_url = event['run_page_url']
- log.info(f'View run status, Spark UI, and logs at {run_page_url}')
+ run_state = RunState.from_json(event["run_state"])
+ run_page_url = event["run_page_url"]
+ log.info("View run status, Spark UI, and logs at %s", run_page_url)
if run_state.is_successful:
- log.info('Job run completed successfully.')
+ log.info("Job run completed successfully.")
return
else:
- error_message = f'Job run failed with terminal state: {run_state}'
+ error_message = f"Job run failed with terminal state: {run_state}"
raise AirflowException(error_message)
@@ -124,23 +144,11 @@ class DatabricksJobRunLink(BaseOperatorLink):
def get_link(
self,
- operator,
- dttm=None,
+ operator: BaseOperator,
*,
- ti_key: Optional["TaskInstanceKey"] = None,
+ ti_key: TaskInstanceKey,
) -> str:
- if ti_key is not None:
- run_page_url = XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key)
- else:
- assert dttm
- run_page_url = XCom.get_one(
- key=XCOM_RUN_PAGE_URL_KEY,
- dag_id=operator.dag.dag_id,
- task_id=operator.task_id,
- execution_date=dttm,
- )
-
- return run_page_url
+ return XCom.get_value(key=XCOM_RUN_PAGE_URL_KEY, ti_key=ti_key)
class DatabricksSubmitRunOperator(BaseOperator):
@@ -150,62 +158,16 @@ class DatabricksSubmitRunOperator(BaseOperator):
`_
API endpoint.
- There are two ways to instantiate this operator.
-
- In the first way, you can take the JSON payload that you typically use
- to call the ``api/2.1/jobs/runs/submit`` endpoint and pass it directly
- to our ``DatabricksSubmitRunOperator`` through the ``json`` parameter.
- For example ::
-
- json = {
- 'new_cluster': {
- 'spark_version': '2.1.0-db3-scala2.11',
- 'num_workers': 2
- },
- 'notebook_task': {
- 'notebook_path': '/Users/airflow@example.com/PrepareData',
- },
- }
- notebook_run = DatabricksSubmitRunOperator(task_id='notebook_run', json=json)
-
- Another way to accomplish the same thing is to use the named parameters
- of the ``DatabricksSubmitRunOperator`` directly. Note that there is exactly
- one named parameter for each top level parameter in the ``runs/submit``
- endpoint. In this method, your code would look like this: ::
-
- new_cluster = {
- 'spark_version': '10.1.x-scala2.12',
- 'num_workers': 2
- }
- notebook_task = {
- 'notebook_path': '/Users/airflow@example.com/PrepareData',
- }
- notebook_run = DatabricksSubmitRunOperator(
- task_id='notebook_run',
- new_cluster=new_cluster,
- notebook_task=notebook_task)
-
- In the case where both the json parameter **AND** the named parameters
- are provided, they will be merged together. If there are conflicts during the merge,
- the named parameters will take precedence and override the top level ``json`` keys.
-
- Currently the named parameters that ``DatabricksSubmitRunOperator`` supports are
- - ``spark_jar_task``
- - ``notebook_task``
- - ``spark_python_task``
- - ``spark_jar_task``
- - ``spark_submit_task``
- - ``pipeline_task``
- - ``new_cluster``
- - ``existing_cluster_id``
- - ``libraries``
- - ``run_name``
- - ``timeout_seconds``
+ There are three ways to instantiate this operator.
.. seealso::
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:DatabricksSubmitRunOperator`
+ :param tasks: Array of Objects(RunSubmitTaskSettings) <= 100 items.
+
+ .. seealso::
+ https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunsSubmit
:param json: A JSON object containing API parameters which will be passed
directly to the ``api/2.1/jobs/runs/submit`` endpoint. The other named parameters
(i.e. ``spark_jar_task``, ``notebook_task``..) to this operator will
@@ -219,28 +181,28 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param spark_jar_task: The main class and parameters for the JAR task. Note that
the actual JAR is specified in the ``libraries``.
*EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task``
- *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified.
+ *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified.
This field will be templated.
.. seealso::
https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparkjartask
:param notebook_task: The notebook path and parameters for the notebook task.
*EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task``
- *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified.
+ *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified.
This field will be templated.
.. seealso::
https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobsnotebooktask
:param spark_python_task: The python file path and parameters to run the python file with.
*EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task``
- *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified.
+ *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified.
This field will be templated.
.. seealso::
https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobssparkpythontask
:param spark_submit_task: Parameters needed to run a spark-submit command.
*EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task``
- *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified.
+ *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified.
This field will be templated.
.. seealso::
@@ -248,11 +210,18 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param pipeline_task: Parameters needed to execute a Delta Live Tables pipeline task.
The provided dictionary must contain at least ``pipeline_id`` field!
*EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task``
- *OR* ``spark_submit_task`` *OR* ``pipeline_task`` should be specified.
+ *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified.
This field will be templated.
.. seealso::
https://docs.databricks.com/dev-tools/api/2.0/jobs.html#jobspipelinetask
+ :param dbt_task: Parameters needed to execute a dbt task.
+ The provided dictionary must contain at least the ``commands`` field and the
+ ``git_source`` parameter also needs to be set.
+ *EITHER* ``spark_jar_task`` *OR* ``notebook_task`` *OR* ``spark_python_task``
+ *OR* ``spark_submit_task`` *OR* ``pipeline_task`` *OR* ``dbt_task`` should be specified.
+ This field will be templated.
+
:param new_cluster: Specs for a new cluster on which this task will be run.
*EITHER* ``new_cluster`` *OR* ``existing_cluster_id`` should be specified
(except when ``pipeline_task`` is used).
@@ -287,7 +256,7 @@ class DatabricksSubmitRunOperator(BaseOperator):
:param databricks_conn_id: Reference to the :ref:`Databricks connection `.
By default and in the common case this will be ``databricks_default``. To use
token based authentication, provide the key ``token`` in the extra field for the
- connection and create the key ``host`` and leave the ``host`` field empty.
+ connection and create the key ``host`` and leave the ``host`` field empty. (templated)
:param polling_period_seconds: Controls the rate at which we poll for the result of
this run. By default the operator will poll every 30 seconds.
:param databricks_retry_limit: Number of times to retry if the Databricks backend is
@@ -304,38 +273,39 @@ class DatabricksSubmitRunOperator(BaseOperator):
"""
# Used in airflow.models.BaseOperator
- template_fields: Sequence[str] = ('json',)
- template_ext: Sequence[str] = ('.json',)
+ template_fields: Sequence[str] = ("json", "databricks_conn_id")
+ template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
- ui_color = '#1CB1C2'
- ui_fgcolor = '#fff'
+ ui_color = "#1CB1C2"
+ ui_fgcolor = "#fff"
operator_extra_links = (DatabricksJobRunLink(),)
def __init__(
self,
*,
- json: Optional[Any] = None,
- tasks: Optional[List[object]] = None,
- spark_jar_task: Optional[Dict[str, str]] = None,
- notebook_task: Optional[Dict[str, str]] = None,
- spark_python_task: Optional[Dict[str, Union[str, List[str]]]] = None,
- spark_submit_task: Optional[Dict[str, List[str]]] = None,
- pipeline_task: Optional[Dict[str, str]] = None,
- new_cluster: Optional[Dict[str, object]] = None,
- existing_cluster_id: Optional[str] = None,
- libraries: Optional[List[Dict[str, str]]] = None,
- run_name: Optional[str] = None,
- timeout_seconds: Optional[int] = None,
- databricks_conn_id: str = 'databricks_default',
+ json: Any | None = None,
+ tasks: list[object] | None = None,
+ spark_jar_task: dict[str, str] | None = None,
+ notebook_task: dict[str, str] | None = None,
+ spark_python_task: dict[str, str | list[str]] | None = None,
+ spark_submit_task: dict[str, list[str]] | None = None,
+ pipeline_task: dict[str, str] | None = None,
+ dbt_task: dict[str, str | list[str]] | None = None,
+ new_cluster: dict[str, object] | None = None,
+ existing_cluster_id: str | None = None,
+ libraries: list[dict[str, str]] | None = None,
+ run_name: str | None = None,
+ timeout_seconds: int | None = None,
+ databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
- databricks_retry_args: Optional[Dict[Any, Any]] = None,
+ databricks_retry_args: dict[Any, Any] | None = None,
do_xcom_push: bool = True,
- idempotency_token: Optional[str] = None,
- access_control_list: Optional[List[Dict[str, str]]] = None,
+ idempotency_token: str | None = None,
+ access_control_list: list[dict[str, str]] | None = None,
wait_for_termination: bool = True,
- git_source: Optional[Dict[str, str]] = None,
+ git_source: dict[str, str] | None = None,
**kwargs,
) -> None:
"""Creates a new ``DatabricksSubmitRunOperator``."""
@@ -348,74 +318,82 @@ def __init__(
self.databricks_retry_args = databricks_retry_args
self.wait_for_termination = wait_for_termination
if tasks is not None:
- self.json['tasks'] = tasks
+ self.json["tasks"] = tasks
if spark_jar_task is not None:
- self.json['spark_jar_task'] = spark_jar_task
+ self.json["spark_jar_task"] = spark_jar_task
if notebook_task is not None:
- self.json['notebook_task'] = notebook_task
+ self.json["notebook_task"] = notebook_task
if spark_python_task is not None:
- self.json['spark_python_task'] = spark_python_task
+ self.json["spark_python_task"] = spark_python_task
if spark_submit_task is not None:
- self.json['spark_submit_task'] = spark_submit_task
+ self.json["spark_submit_task"] = spark_submit_task
if pipeline_task is not None:
- self.json['pipeline_task'] = pipeline_task
+ self.json["pipeline_task"] = pipeline_task
+ if dbt_task is not None:
+ self.json["dbt_task"] = dbt_task
if new_cluster is not None:
- self.json['new_cluster'] = new_cluster
+ self.json["new_cluster"] = new_cluster
if existing_cluster_id is not None:
- self.json['existing_cluster_id'] = existing_cluster_id
+ self.json["existing_cluster_id"] = existing_cluster_id
if libraries is not None:
- self.json['libraries'] = libraries
+ self.json["libraries"] = libraries
if run_name is not None:
- self.json['run_name'] = run_name
+ self.json["run_name"] = run_name
if timeout_seconds is not None:
- self.json['timeout_seconds'] = timeout_seconds
- if 'run_name' not in self.json:
- self.json['run_name'] = run_name or kwargs['task_id']
+ self.json["timeout_seconds"] = timeout_seconds
+ if "run_name" not in self.json:
+ self.json["run_name"] = run_name or kwargs["task_id"]
if idempotency_token is not None:
- self.json['idempotency_token'] = idempotency_token
+ self.json["idempotency_token"] = idempotency_token
if access_control_list is not None:
- self.json['access_control_list'] = access_control_list
+ self.json["access_control_list"] = access_control_list
if git_source is not None:
- self.json['git_source'] = git_source
+ self.json["git_source"] = git_source
- self.json = deep_string_coerce(self.json)
+ if "dbt_task" in self.json and "git_source" not in self.json:
+ raise AirflowException("git_source is required for dbt_task")
+
+ self.json = normalise_json_content(self.json)
# This variable will be used in case our task gets killed.
- self.run_id: Optional[int] = None
+ self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
- def _get_hook(self) -> DatabricksHook:
+ @cached_property
+ def _hook(self):
+ return self._get_hook(caller="DatabricksSubmitRunOperator")
+
+ def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
+ caller=caller,
)
- def execute(self, context: 'Context'):
- hook = self._get_hook()
- self.run_id = hook.submit_run(self.json)
- _handle_databricks_operator_execution(self, hook, self.log, context)
+ def execute(self, context: Context):
+ self.run_id = self._hook.submit_run(self.json)
+ _handle_databricks_operator_execution(self, self._hook, self.log, context)
def on_kill(self):
if self.run_id:
- hook = self._get_hook()
- hook.cancel_run(self.run_id)
+ self._hook.cancel_run(self.run_id)
self.log.info(
- 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id
+ "Task: %s with run_id: %s was requested to be cancelled.", self.task_id, self.run_id
)
else:
- self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id)
+ self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id)
class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
"""Deferrable version of ``DatabricksSubmitRunOperator``"""
def execute(self, context):
- hook = self._get_hook()
+ hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
self.run_id = hook.submit_run(self.json)
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
- def execute_complete(self, context: Optional[dict], event: dict):
+ def execute_complete(self, context: dict | None, event: dict):
_handle_deferrable_databricks_operator_completion(event, self.log)
@@ -508,7 +486,7 @@ class DatabricksRunNowOperator(BaseOperator):
The map is passed to the notebook and will be accessible through the
dbutils.widgets.get function. See Widgets for more information.
If not specified upon run-now, the triggered run will use the
- job’s base parameters. notebook_params cannot be
+ job's base parameters. notebook_params cannot be
specified in conjunction with jar_params. The json representation
of this field (i.e. {"notebook_params":{"name":"john doe","age":"35"}})
cannot exceed 10,000 bytes.
@@ -526,8 +504,8 @@ class DatabricksRunNowOperator(BaseOperator):
.. seealso::
https://docs.databricks.com/dev-tools/api/latest/jobs.html#operation/JobsRunNow
- :param python_named_parameters: A list of parameters for jobs with python wheel tasks,
- e.g. "python_named_parameters": {"name": "john doe", "age": "35"}.
+ :param python_named_params: A dictionary of named parameters for jobs with Python wheel tasks,
+ e.g. "python_named_params": {"name": "john doe", "age": "35"}.
If specified upon run-now, it would overwrite the parameters specified in job setting.
This field will be templated.
@@ -560,9 +538,9 @@ class DatabricksRunNowOperator(BaseOperator):
:param databricks_conn_id: Reference to the :ref:`Databricks connection `.
By default and in the common case this will be ``databricks_default``. To use
token based authentication, provide the key ``token`` in the extra field for the
- connection and create the key ``host`` and leave the ``host`` field empty.
+ connection and create the key ``host`` and leave the ``host`` field empty. (templated)
:param polling_period_seconds: Controls the rate at which we poll for the result of
- this run. By default the operator will poll every 30 seconds.
+ this run. By default, the operator will poll every 30 seconds.
:param databricks_retry_limit: Number of times to retry if the Databricks backend is
unreachable. Its value must be greater than or equal to 1.
:param databricks_retry_delay: Number of seconds to wait between retries (it
@@ -573,30 +551,30 @@ class DatabricksRunNowOperator(BaseOperator):
"""
# Used in airflow.models.BaseOperator
- template_fields: Sequence[str] = ('json',)
- template_ext: Sequence[str] = ('.json',)
+ template_fields: Sequence[str] = ("json", "databricks_conn_id")
+ template_ext: Sequence[str] = (".json-tpl",)
# Databricks brand color (blue) under white text
- ui_color = '#1CB1C2'
- ui_fgcolor = '#fff'
+ ui_color = "#1CB1C2"
+ ui_fgcolor = "#fff"
operator_extra_links = (DatabricksJobRunLink(),)
def __init__(
self,
*,
- job_id: Optional[str] = None,
- job_name: Optional[str] = None,
- json: Optional[Any] = None,
- notebook_params: Optional[Dict[str, str]] = None,
- python_params: Optional[List[str]] = None,
- jar_params: Optional[List[str]] = None,
- spark_submit_params: Optional[List[str]] = None,
- python_named_parameters: Optional[Dict[str, str]] = None,
- idempotency_token: Optional[str] = None,
- databricks_conn_id: str = 'databricks_default',
+ job_id: str | None = None,
+ job_name: str | None = None,
+ json: Any | None = None,
+ notebook_params: dict[str, str] | None = None,
+ python_params: list[str] | None = None,
+ jar_params: list[str] | None = None,
+ spark_submit_params: list[str] | None = None,
+ python_named_params: dict[str, str] | None = None,
+ idempotency_token: str | None = None,
+ databricks_conn_id: str = "databricks_default",
polling_period_seconds: int = 30,
databricks_retry_limit: int = 3,
databricks_retry_delay: int = 1,
- databricks_retry_args: Optional[Dict[Any, Any]] = None,
+ databricks_retry_args: dict[Any, Any] | None = None,
do_xcom_push: bool = True,
wait_for_termination: bool = True,
**kwargs,
@@ -612,66 +590,70 @@ def __init__(
self.wait_for_termination = wait_for_termination
if job_id is not None:
- self.json['job_id'] = job_id
+ self.json["job_id"] = job_id
if job_name is not None:
- self.json['job_name'] = job_name
- if 'job_id' in self.json and 'job_name' in self.json:
+ self.json["job_name"] = job_name
+ if "job_id" in self.json and "job_name" in self.json:
raise AirflowException("Argument 'job_name' is not allowed with argument 'job_id'")
if notebook_params is not None:
- self.json['notebook_params'] = notebook_params
+ self.json["notebook_params"] = notebook_params
if python_params is not None:
- self.json['python_params'] = python_params
- if python_named_parameters is not None:
- self.json['python_named_parameters'] = python_named_parameters
+ self.json["python_params"] = python_params
+ if python_named_params is not None:
+ self.json["python_named_params"] = python_named_params
if jar_params is not None:
- self.json['jar_params'] = jar_params
+ self.json["jar_params"] = jar_params
if spark_submit_params is not None:
- self.json['spark_submit_params'] = spark_submit_params
+ self.json["spark_submit_params"] = spark_submit_params
if idempotency_token is not None:
- self.json['idempotency_token'] = idempotency_token
+ self.json["idempotency_token"] = idempotency_token
- self.json = deep_string_coerce(self.json)
+ self.json = normalise_json_content(self.json)
# This variable will be used in case our task gets killed.
- self.run_id: Optional[int] = None
+ self.run_id: int | None = None
self.do_xcom_push = do_xcom_push
- def _get_hook(self) -> DatabricksHook:
+ @cached_property
+ def _hook(self):
+ return self._get_hook(caller="DatabricksRunNowOperator")
+
+ def _get_hook(self, caller: str) -> DatabricksHook:
return DatabricksHook(
self.databricks_conn_id,
retry_limit=self.databricks_retry_limit,
retry_delay=self.databricks_retry_delay,
retry_args=self.databricks_retry_args,
+ caller=caller,
)
- def execute(self, context: 'Context'):
- hook = self._get_hook()
- if 'job_name' in self.json:
- job_id = hook.find_job_id_by_name(self.json['job_name'])
+ def execute(self, context: Context):
+ hook = self._hook
+ if "job_name" in self.json:
+ job_id = hook.find_job_id_by_name(self.json["job_name"])
if job_id is None:
raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
- self.json['job_id'] = job_id
- del self.json['job_name']
+ self.json["job_id"] = job_id
+ del self.json["job_name"]
self.run_id = hook.run_now(self.json)
_handle_databricks_operator_execution(self, hook, self.log, context)
def on_kill(self):
if self.run_id:
- hook = self._get_hook()
- hook.cancel_run(self.run_id)
+ self._hook.cancel_run(self.run_id)
self.log.info(
- 'Task: %s with run_id: %s was requested to be cancelled.', self.task_id, self.run_id
+ "Task: %s with run_id: %s was requested to be cancelled.", self.task_id, self.run_id
)
else:
- self.log.error('Error: Task: %s with invalid run_id was requested to be cancelled.', self.task_id)
+ self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id)
class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
"""Deferrable version of ``DatabricksRunNowOperator``"""
def execute(self, context):
- hook = self._get_hook()
+ hook = self._get_hook(caller="DatabricksRunNowDeferrableOperator")
self.run_id = hook.run_now(self.json)
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
- def execute_complete(self, context: Optional[dict], event: dict):
+ def execute_complete(self, context: dict | None, event: dict):
_handle_deferrable_databricks_operator_completion(event, self.log)
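Reviewer note: the hunks above replace per-call ``self._get_hook()`` with a ``_hook`` attribute wrapped in ``cached_property``, so ``execute`` and ``on_kill`` share one hook instance, while the deferrable subclasses still call ``_get_hook(caller=...)`` directly. The following is a minimal sketch of that pattern, not the provider's code: ``StubHook`` and ``StubOperator`` are hypothetical stand-ins for ``DatabricksHook`` and the operator, and it assumes the standard-library ``functools.cached_property`` (Python 3.8+) rather than ``airflow.compat.functools``.

```python
from __future__ import annotations

from functools import cached_property


class StubHook:
    """Hypothetical stand-in for DatabricksHook; it only records what it is asked to do."""

    instances_created = 0  # counts constructions so the caching effect is observable

    def __init__(self, conn_id: str, caller: str = "unknown") -> None:
        StubHook.instances_created += 1
        self.conn_id = conn_id
        self.caller = caller
        self.cancelled: list[int] = []

    def cancel_run(self, run_id: int) -> None:
        self.cancelled.append(run_id)


class StubOperator:
    """Hypothetical operator mirroring the _hook / _get_hook split shown in the diff."""

    def __init__(self, conn_id: str = "databricks_default") -> None:
        self.conn_id = conn_id
        self.run_id: int | None = None

    @cached_property
    def _hook(self) -> StubHook:
        # Built on first access, then stored on the instance and reused.
        return self._get_hook(caller="StubOperator")

    def _get_hook(self, caller: str) -> StubHook:
        # Always builds a fresh hook; the caller label is passed through.
        return StubHook(self.conn_id, caller=caller)

    def on_kill(self) -> None:
        if self.run_id:
            self._hook.cancel_run(self.run_id)
```

In this sketch, repeated access to ``op._hook`` returns the same object, whereas each ``op._get_hook(caller=...)`` call constructs a new one. That is presumably why the deferrable variants in the diff, which pass their own class name as ``caller``, keep calling ``_get_hook`` explicitly.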
diff --git a/airflow/providers/databricks/operators/databricks_repos.py b/airflow/providers/databricks/operators/databricks_repos.py
index 97b46b8e244d6..f42114d474267 100644
--- a/airflow/providers/databricks/operators/databricks_repos.py
+++ b/airflow/providers/databricks/operators/databricks_repos.py
@@ -15,12 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-#
"""This module contains Databricks operators."""
+from __future__ import annotations
+
import re
-from typing import TYPE_CHECKING, Optional, Sequence
-from urllib.parse import urlparse
+from typing import TYPE_CHECKING, Sequence
+from urllib.parse import urlsplit
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.databricks.hooks.databricks import DatabricksHook
@@ -46,7 +48,7 @@ class DatabricksReposCreateOperator(BaseOperator):
:param databricks_conn_id: Reference to the :ref:`Databricks connection