Skip to content

Commit

Permalink
Add cli command for 'airflow dags reserialize` (apache#19471)
Browse files Browse the repository at this point in the history
  • Loading branch information
collinmcnulty authored and Dillon Johnson committed Dec 1, 2021
1 parent b241203 commit 9d32d48
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
18 changes: 18 additions & 0 deletions airflow/cli/cli_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,6 +769,13 @@ def _check(value):
help="The maximum number of triggers that a Triggerer will run at one time.",
)

# reserialize
ARG_CLEAR_ONLY = Arg(
("--clear-only",),
action="store_true",
help="If passed, serialized DAGs will be cleared but not reserialized.",
)

ALTERNATIVE_CONN_SPECS_ARGS = [
ARG_CONN_TYPE,
ARG_CONN_DESCRIPTION,
Expand Down Expand Up @@ -977,6 +984,17 @@ class GroupCommand(NamedTuple):
ARG_SAVE_DAGRUN,
),
),
ActionCommand(
name='reserialize',
help="Reserialize all DAGs by parsing the DagBag files",
description=(
"Drop all serialized dags from the metadata DB. This will cause all DAGs to be reserialized "
"from the DagBag folder. This can be helpful if your serialized DAGs get out of sync with the "
"version of Airflow that you are running."
),
func=lazy_load_command('airflow.cli.commands.dag_command.dag_reserialize'),
args=(ARG_CLEAR_ONLY,),
),
)
TASKS_COMMANDS = (
ActionCommand(
Expand Down
12 changes: 12 additions & 0 deletions airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from airflow.jobs.base_job import BaseJob
from airflow.models import DagBag, DagModel, DagRun, TaskInstance
from airflow.models.dag import DAG
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
Expand Down Expand Up @@ -441,3 +442,14 @@ def dag_test(args, session=None):
_display_dot_via_imgcat(dot_graph)
if show_dagrun:
print(dot_graph.source)


@provide_session
@cli_utils.action_logging
def dag_reserialize(args, session=None):
session.query(SerializedDagModel).delete(synchronize_session=False)

if not args.clear_only:
dagbag = DagBag()
dagbag.collect_dags(only_if_updated=False, safe_mode=False)
dagbag.sync_to_db()
3 changes: 3 additions & 0 deletions docs/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ Redhat
ReidentifyContentResponse
Reinitialising
Remoting
Reserialize
ResourceRequirements
Roadmap
Robinhood
Expand Down Expand Up @@ -1167,6 +1168,8 @@ replicaSet
repo
repos
reqs
reserialize
reserialized
resetdb
resourceVersion
resultset
Expand Down
22 changes: 22 additions & 0 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from airflow.cli.commands import dag_command
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
Expand Down Expand Up @@ -62,6 +63,27 @@ def tearDownClass(cls) -> None:
clear_db_runs()
clear_db_dags()

def test_reserialize(self):
# Assert that there are serialized Dags
with create_session() as session:
serialized_dags_before_command = session.query(SerializedDagModel).all()
assert len(serialized_dags_before_command) # There are serialized DAGs to delete

# Run clear of serialized dags
dag_command.dag_reserialize(self.parser.parse_args(['dags', 'reserialize', "--clear-only"]))
# Assert no serialized Dags
with create_session() as session:
serialized_dags_after_clear = session.query(SerializedDagModel).all()
assert not len(serialized_dags_after_clear)

# Serialize manually
dag_command.dag_reserialize(self.parser.parse_args(['dags', 'reserialize']))

# Check serialized DAGs are back
with create_session() as session:
serialized_dags_after_reserialize = session.query(SerializedDagModel).all()
assert len(serialized_dags_after_reserialize) >= 40 # Serialized DAGs back

@mock.patch("airflow.cli.commands.dag_command.DAG.run")
def test_backfill(self, mock_run):
dag_command.dag_backfill(
Expand Down

0 comments on commit 9d32d48

Please sign in to comment.