
Enable Black on Connexion API folders (apache#10545)
kaxil authored Aug 25, 2020
1 parent d6ce8c8 commit 7c0d6ab
Showing 62 changed files with 641 additions and 1,017 deletions.
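
Taken together, the commit adds a Black hook to .pre-commit-config.yaml scoped to the api_connexion package and driven by pyproject.toml, adds E231 to the flake8 ignore list, extends the isort hook's exclude pattern so api_connexion files are skipped (keeping it consistent with the global isort skip config in setup.cfg), and reformats the api_connexion endpoint modules with Black.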
2 changes: 1 addition & 1 deletion .flake8
@@ -1,6 +1,6 @@
[flake8]
max-line-length = 110
ignore = E731,W504,I001,W503
ignore = E231,E731,W504,I001,W503
exclude = .svn,CVS,.bzr,.hg,.git,__pycache__,.eggs,*.egg,node_modules
format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s
per-file-ignores =
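
The one functional change to .flake8 is adding E231 ("missing whitespace after ','") to the ignore list. Black's output in this commit can leave a comma directly followed by a closing brace (see the collapsed scheduler dict in health_endpoint.py further down), which pycodestyle would otherwise report as E231. A minimal illustration, not taken from the repository:

# In this commit, Black collapsed a dict that carried a trailing comma onto one
# line and kept the comma right before the closing brace:
payload = {"status": "healthy", "latest_scheduler_heartbeat": None,}
# pycodestyle flags the ',}' as E231 (missing whitespace after ','), so the
# check is ignored now that Black formats these files.
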
8 changes: 7 additions & 1 deletion .pre-commit-config.yaml
@@ -149,6 +149,12 @@ repos:
- repo: meta
hooks:
- id: check-hooks-apply
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black
files: api_connexion/.*\.py
args: [--config=./pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
@@ -184,7 +190,7 @@ repos:
name: Run isort to sort imports
types: [python]
# To keep consistent with the global isort skip config defined in setup.cfg
exclude: ^build/.*$|^.tox/.*$|^venv/.*$
exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py
- repo: https://github.com/pycqa/pydocstyle
rev: 5.0.2
hooks:
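
The new Black hook only touches files matching api_connexion/.*\.py and takes its settings from ./pyproject.toml; that file is not part of this diff, so the 110-character line length used below is an assumption inferred from flake8's max-line-length. A rough local equivalent of what the hook runs, sketched in Python:

import subprocess
import sys

# Sketch only: run Black in check mode over the Connexion API package.
# Assumes Black is installed and that 110 matches the line length configured
# in the repository's pyproject.toml (not shown in this diff).
result = subprocess.run(
    ["black", "--check", "--line-length", "110", "airflow/api_connexion/"],
    capture_output=True,
    text=True,
)
print(result.stdout or result.stderr, end="")
sys.exit(result.returncode)

The isort hook's exclude pattern is also extended with .*api_connexion/.*\.py, matching the comment above about staying consistent with the global isort skip config in setup.cfg.
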
12 changes: 5 additions & 7 deletions airflow/api_connexion/endpoints/config_endpoint.py
@@ -30,11 +30,7 @@ def _conf_dict_to_config(conf_dict: dict) -> Config:
config = Config(
sections=[
ConfigSection(
name=section,
options=[
ConfigOption(key=key, value=value)
for key, value in options.items()
]
name=section, options=[ConfigOption(key=key, value=value) for key, value in options.items()]
)
for section, options in conf_dict.items()
]
@@ -49,8 +45,10 @@ def _option_to_text(config_option: ConfigOption) -> str:

def _section_to_text(config_section: ConfigSection) -> str:
"""Convert a single config section to text"""
return (f'[{config_section.name}]{LINE_SEP}'
f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}')
return (
f'[{config_section.name}]{LINE_SEP}'
f'{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}'
)


def _config_to_text(config: Config) -> str:
14 changes: 8 additions & 6 deletions airflow/api_connexion/endpoints/connection_endpoint.py
@@ -24,7 +24,10 @@
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.connection_schema import (
ConnectionCollection, connection_collection_item_schema, connection_collection_schema, connection_schema,
ConnectionCollection,
connection_collection_item_schema,
connection_collection_schema,
connection_schema,
)
from airflow.models import Connection
from airflow.utils.session import provide_session
@@ -56,9 +59,7 @@ def get_connection(connection_id, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_connections(session, limit, offset=0):
"""
@@ -67,8 +68,9 @@ def get_connections(session, limit, offset=0):
total_entries = session.query(func.count(Connection.id)).scalar()
query = session.query(Connection)
connections = query.order_by(Connection.id).offset(offset).limit(limit).all()
return connection_collection_schema.dump(ConnectionCollection(connections=connections,
total_entries=total_entries))
return connection_collection_schema.dump(
ConnectionCollection(connections=connections, total_entries=total_entries)
)


@security.requires_authentication
9 changes: 5 additions & 4 deletions airflow/api_connexion/endpoints/dag_endpoint.py
@@ -23,7 +23,10 @@
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.dag_schema import (
DAGCollection, dag_detail_schema, dag_schema, dags_collection_schema,
DAGCollection,
dag_detail_schema,
dag_schema,
dags_collection_schema,
)
from airflow.models.dag import DagModel
from airflow.utils.session import provide_session
@@ -55,9 +58,7 @@ def get_dag_details(dag_id):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_dags(session, limit, offset=0):
"""
105 changes: 68 additions & 37 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -23,7 +23,10 @@
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
DAGRunCollection, dagrun_collection_schema, dagrun_schema, dagruns_batch_form_schema,
DAGRunCollection,
dagrun_collection_schema,
dagrun_schema,
dagruns_batch_form_schema,
)
from airflow.models import DagModel, DagRun
from airflow.utils.session import provide_session
@@ -36,11 +39,7 @@ def delete_dag_run(dag_id, dag_run_id, session):
"""
Delete a DAG Run
"""
if (
session.query(DagRun)
.filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id)
.delete() == 0
):
if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0:
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, 204

@@ -51,23 +50,24 @@ def get_dag_run(dag_id, dag_run_id, session):
"""
Get a DAG Run.
"""
dag_run = session.query(DagRun).filter(
DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
if dag_run is None:
raise NotFound("DAGRun not found")
return dagrun_schema.dump(dag_run)


@security.requires_authentication
@format_parameters({
'start_date_gte': format_datetime,
'start_date_lte': format_datetime,
'execution_date_gte': format_datetime,
'execution_date_lte': format_datetime,
'end_date_gte': format_datetime,
'end_date_lte': format_datetime,
'limit': check_limit
})
@format_parameters(
{
'start_date_gte': format_datetime,
'start_date_lte': format_datetime,
'execution_date_gte': format_datetime,
'execution_date_lte': format_datetime,
'end_date_gte': format_datetime,
'end_date_lte': format_datetime,
'limit': check_limit,
}
)
@provide_session
def get_dag_runs(
session,
@@ -91,27 +91,52 @@ def get_dag_runs(
if dag_id != "~":
query = query.filter(DagRun.dag_id == dag_id)

dag_run, total_entries = _fetch_dag_runs(query, session, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte,
limit, offset)
dag_run, total_entries = _fetch_dag_runs(
query,
session,
end_date_gte,
end_date_lte,
execution_date_gte,
execution_date_lte,
start_date_gte,
start_date_lte,
limit,
offset,
)

return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run,
total_entries=total_entries))
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries))


def _fetch_dag_runs(query, session, end_date_gte, end_date_lte,
execution_date_gte, execution_date_lte,
start_date_gte, start_date_lte, limit, offset):
query = _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte)
def _fetch_dag_runs(
query,
session,
end_date_gte,
end_date_lte,
execution_date_gte,
execution_date_lte,
start_date_gte,
start_date_lte,
limit,
offset,
):
query = _apply_date_filters_to_query(
query,
end_date_gte,
end_date_lte,
execution_date_gte,
execution_date_lte,
start_date_gte,
start_date_lte,
)
# apply offset and limit
dag_run = query.order_by(DagRun.id).offset(offset).limit(limit).all()
total_entries = session.query(func.count(DagRun.id)).scalar()
return dag_run, total_entries


def _apply_date_filters_to_query(query, end_date_gte, end_date_lte, execution_date_gte,
execution_date_lte, start_date_gte, start_date_lte):
def _apply_date_filters_to_query(
query, end_date_gte, end_date_lte, execution_date_gte, execution_date_lte, start_date_gte, start_date_lte
):
# filter start date
if start_date_gte:
query = query.filter(DagRun.start_date >= start_date_gte)
@@ -147,13 +172,20 @@ def get_dag_runs_batch(session):
if data["dag_ids"]:
query = query.filter(DagRun.dag_id.in_(data["dag_ids"]))

dag_runs, total_entries = _fetch_dag_runs(query, session, data["end_date_gte"], data["end_date_lte"],
data["execution_date_gte"], data["execution_date_lte"],
data["start_date_gte"], data["start_date_lte"],
data["page_limit"], data["page_offset"])
dag_runs, total_entries = _fetch_dag_runs(
query,
session,
data["end_date_gte"],
data["end_date_lte"],
data["execution_date_gte"],
data["execution_date_lte"],
data["start_date_gte"],
data["start_date_lte"],
data["page_limit"],
data["page_offset"],
)

return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs,
total_entries=total_entries))
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@security.requires_authentication
@@ -167,8 +199,7 @@ def post_dag_run(dag_id, session):

post_body = dagrun_schema.load(request.json, session=session)
dagrun_instance = (
session.query(DagRun).filter(
DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]).first()
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == post_body["run_id"]).first()
)
if not dagrun_instance:
dag_run = DagRun(dag_id=dag_id, run_type=DagRunType.MANUAL.value, **post_body)
13 changes: 7 additions & 6 deletions airflow/api_connexion/endpoints/event_log_endpoint.py
@@ -22,7 +22,9 @@
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.event_log_schema import (
EventLogCollection, event_log_collection_schema, event_log_schema,
EventLogCollection,
event_log_collection_schema,
event_log_schema,
)
from airflow.models import Log
from airflow.utils.session import provide_session
@@ -41,9 +43,7 @@ def get_event_log(event_log_id, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_event_logs(session, limit, offset=None):
"""
@@ -52,5 +52,6 @@ def get_event_logs(session, limit, offset=None):

total_entries = session.query(func.count(Log.id)).scalar()
event_logs = session.query(Log).order_by(Log.id).offset(offset).limit(limit).all()
return event_log_collection_schema.dump(EventLogCollection(event_logs=event_logs,
total_entries=total_entries))
return event_log_collection_schema.dump(
EventLogCollection(event_logs=event_logs, total_entries=total_entries)
)
5 changes: 1 addition & 4 deletions airflow/api_connexion/endpoints/health_endpoint.py
@@ -40,10 +40,7 @@ def get_health():

payload = {
"metadatabase": {"status": metadatabase_status},
"scheduler": {
"status": scheduler_status,
"latest_scheduler_heartbeat": latest_scheduler_heartbeat,
},
"scheduler": {"status": scheduler_status, "latest_scheduler_heartbeat": latest_scheduler_heartbeat,},
}

return health_schema.dump(payload)
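
The trailing ',}' in the collapsed scheduler mapping above is the Black output pattern that the new E231 entry in the .flake8 ignore list accommodates.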
8 changes: 4 additions & 4 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
@@ -21,7 +21,9 @@
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.parameters import check_limit, format_parameters
from airflow.api_connexion.schemas.error_schema import (
ImportErrorCollection, import_error_collection_schema, import_error_schema,
ImportErrorCollection,
import_error_collection_schema,
import_error_schema,
)
from airflow.models.errors import ImportError # pylint: disable=redefined-builtin
from airflow.utils.session import provide_session
@@ -41,9 +43,7 @@ def get_import_error(import_error_id, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_import_errors(session, limit, offset=None):
"""
12 changes: 3 additions & 9 deletions airflow/api_connexion/endpoints/log_endpoint.py
@@ -29,8 +29,7 @@

@security.requires_authentication
@provide_session
def get_log(session, dag_id, dag_run_id, task_id, task_try_number,
full_content=False, token=None):
def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=False, token=None):
"""
Get logs for specific task instance
"""
@@ -77,13 +76,8 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number,
logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
logs = logs[0] if task_try_number is not None else logs
token = URLSafeSerializer(key).dumps(metadata)
return logs_schema.dump(LogResponseObject(continuation_token=token,
content=logs)
)
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)

return Response(
logs,
headers={"Content-Type": return_type}
)
return Response(logs, headers={"Content-Type": return_type})
8 changes: 2 additions & 6 deletions airflow/api_connexion/endpoints/pool_endpoint.py
@@ -54,9 +54,7 @@ def get_pool(pool_name, session):


@security.requires_authentication
@format_parameters({
'limit': check_limit
})
@format_parameters({'limit': check_limit})
@provide_session
def get_pools(session, limit, offset=None):
"""
@@ -65,9 +63,7 @@ def get_pools(session, limit, offset=None):

total_entries = session.query(func.count(Pool.id)).scalar()
pools = session.query(Pool).order_by(Pool.id).offset(offset).limit(limit).all()
return pool_collection_schema.dump(
PoolCollection(pools=pools, total_entries=total_entries)
)
return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries))


@security.requires_authentication
(The remaining changed files in this commit are not shown in this excerpt.)
