Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decorators keep functions signatures #9786

Merged
merged 1 commit into from
Jul 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions airflow/api/auth/backend/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
"""Default authentication backend - everything is allowed"""
from functools import wraps
from typing import Optional
from typing import Callable, Optional, TypeVar, cast

from airflow.typing_compat import Protocol

Expand All @@ -40,10 +40,13 @@ def init_app(_):
"""Initializes authentication backend"""


def requires_authentication(function):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
@wraps(function)
def decorated(*args, **kwargs):
return function(*args, **kwargs)

return decorated
return cast(T, decorated)
9 changes: 6 additions & 3 deletions airflow/api/auth/backend/deny_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
"""Authentication backend that denies all requests"""
from functools import wraps
from typing import Optional
from typing import Callable, Optional, TypeVar, cast

from flask import Response

Expand All @@ -30,12 +30,15 @@ def init_app(_):
"""Initializes authentication"""


def requires_authentication(function):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def requires_authentication(function: T):
"""Decorator for functions that require authentication"""

# noinspection PyUnusedLocal
@wraps(function)
def decorated(*args, **kwargs): # pylint: disable=unused-argument
return Response("Forbidden", 403)

return decorated
return cast(T, decorated)
8 changes: 6 additions & 2 deletions airflow/api/auth/backend/kerberos_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import os
from functools import wraps
from socket import getfqdn
from typing import Callable, TypeVar, cast

import kerberos
# noinspection PyProtectedMember
Expand Down Expand Up @@ -126,7 +127,10 @@ def _gssapi_authenticate(token):
kerberos.authGSSServerClean(state)


def requires_authentication(function):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def requires_authentication(function: T):
"""Decorator for functions that require authentication with Kerberos"""
@wraps(function)
def decorated(*args, **kwargs):
Expand All @@ -147,4 +151,4 @@ def decorated(*args, **kwargs):
if return_code != kerberos.AUTH_GSS_CONTINUE:
return _forbidden()
return _unauthorized()
return decorated
return cast(T, decorated)
11 changes: 7 additions & 4 deletions airflow/api_connexion/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from functools import wraps
from typing import Callable, Dict
from typing import Callable, Dict, TypeVar, cast

from pendulum.parsing import ParserError

Expand Down Expand Up @@ -59,7 +59,10 @@ def check_limit(value: int):
return value


def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def format_parameters(params_formatters: Dict[str, Callable[..., bool]]) -> Callable[[T], T]:
"""
Decorator factory that create decorator that convert parameters using given formatters.

Expand All @@ -68,14 +71,14 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
:param params_formatters: Map of key name and formatter function
"""

def format_parameters_decorator(func):
def format_parameters_decorator(func: T):
@wraps(func)
def wrapped_function(*args, **kwargs):
for key, formatter in params_formatters.items():
if key in kwargs:
kwargs[key] = formatter(kwargs[key])
return func(*args, **kwargs)

return wrapped_function
return cast(T, wrapped_function)

return format_parameters_decorator
13 changes: 8 additions & 5 deletions airflow/lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import json
import logging
from functools import wraps
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, TypeVar, cast

import attr
import jinja2
Expand Down Expand Up @@ -79,7 +79,10 @@ def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
return Metadata(type_name, source, data)


def apply_lineage(func):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def apply_lineage(func: T) -> T:
"""
Saves the lineage to XCom and if configured to do so sends it
to the backend.
Expand Down Expand Up @@ -110,10 +113,10 @@ def wrapper(self, context, *args, **kwargs):

return ret_val

return wrapper
return cast(T, wrapper)


def prepare_lineage(func):
def prepare_lineage(func: T) -> T:
"""
Prepares the lineage inlets and outlets. Inlets can be:

Expand Down Expand Up @@ -172,4 +175,4 @@ def wrapper(self, context, *args, **kwargs):
self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
return func(self, context, *args, **kwargs)

return wrapper
return cast(T, wrapper)
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
Base = declarative_base()


class DagRun(Base):
class DagRun(Base): # type: ignore
"""
DagRun describes an instance of a Dag. It can be created
by the scheduler (for regular runs) or by an external trigger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
ID_LEN = 250


class TaskInstance(Base):
class TaskInstance(Base): # type: ignore
"""
Task instances store the state of a task instance. This table is the
authority and single source of truth around what tasks have run and the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
ID_LEN = 250


class TaskInstance(Base): # noqa: D101
class TaskInstance(Base): # noqa: D101 # type: ignore
__tablename__ = "task_instance"

task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
Expand Down
15 changes: 11 additions & 4 deletions airflow/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from itertools import islice
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast

import dill

Expand Down Expand Up @@ -254,7 +254,14 @@ def execute(self, context: Dict):
return return_value


def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = False, **kwargs):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def task(
python_callable: Optional[Callable] = None,
multiple_outputs: bool = False,
**kwargs
) -> Callable[[T], T]:
"""
Python operator decorator. Wraps a function into an Airflow operator.
Accepts kwargs for operator kwarg. Can be reused in a single DAG.
Expand All @@ -268,7 +275,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa
:type multiple_outputs: bool

"""
def wrapper(f):
def wrapper(f: T):
"""
Python wrapper to generate PythonFunctionalOperator out of simple python functions.
Used for Airflow functional interface
Expand All @@ -281,7 +288,7 @@ def factory(*args, **f_kwargs):
op = _PythonFunctionalOperator(python_callable=f, op_args=args, op_kwargs=f_kwargs,
multiple_outputs=multiple_outputs, **kwargs)
return XComArg(op)
return factory
return cast(T, factory)
if callable(python_callable):
return wrapper(python_callable)
elif python_callable is not None:
Expand Down
12 changes: 7 additions & 5 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from functools import wraps
from inspect import signature
from tempfile import NamedTemporaryFile
from typing import Optional
from typing import Callable, Optional, TypeVar, cast
from urllib.parse import urlparse

from botocore.exceptions import ClientError
Expand All @@ -37,8 +37,10 @@
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import chunks

T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name

def provide_bucket_name(func):

def provide_bucket_name(func: T) -> T:
"""
Function decorator that provides a bucket name taken from the connection
in case no bucket name has been passed to the function.
Expand All @@ -59,10 +61,10 @@ def wrapper(*args, **kwargs):

return func(*bound_args.args, **bound_args.kwargs)

return wrapper
return cast(T, wrapper)


def unify_bucket_name_and_key(func):
def unify_bucket_name_and_key(func: T) -> T:
"""
Function decorator that unifies bucket name and key taken from the key
in case no bucket name and at least a key has been passed to the function.
Expand All @@ -88,7 +90,7 @@ def get_key_name():

return func(*bound_args.args, **bound_args.kwargs)

return wrapper
return cast(T, wrapper)


class S3Hook(AwsBaseHook):
Expand Down
12 changes: 6 additions & 6 deletions airflow/providers/google/cloud/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import warnings
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast

from googleapiclient.discovery import build

Expand All @@ -47,20 +47,20 @@
r'Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]'
)

RT = TypeVar('RT') # pylint: disable=invalid-name
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def _fallback_variable_parameter(parameter_name, variable_key_name):
def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]:

def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]:
def _wrapper(func: T) -> T:
"""
Decorator that provides fallback for location from `region` key in `variables` parameters.

:param func: function to wrap
:return: result of the function call
"""
@functools.wraps(func)
def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
def inner_wrapper(self: "DataflowHook", *args, **kwargs):
if args:
raise AirflowException(
"You must use keyword arguments in this methods rather than positional")
Expand All @@ -81,7 +81,7 @@ def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
kwargs['variables'] = copy_variables

return func(self, *args, **kwargs)
return inner_wrapper
return cast(T, inner_wrapper)

return _wrapper

Expand Down
9 changes: 5 additions & 4 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from io import BytesIO
from os import path
from tempfile import NamedTemporaryFile
from typing import Optional, Set, Tuple, TypeVar, Union
from typing import Callable, Optional, Set, Tuple, TypeVar, Union, cast
from urllib.parse import urlparse

from google.api_core.exceptions import NotFound
Expand All @@ -39,13 +39,14 @@
from airflow.version import version

RT = TypeVar('RT') # pylint: disable=invalid-name
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name


def _fallback_object_url_to_object_name_and_bucket_name(
object_url_keyword_arg_name='object_url',
bucket_name_keyword_arg_name='bucket_name',
object_name_keyword_arg_name='object_name',
):
) -> Callable[[T], T]:
"""
Decorator factory that convert object URL parameter to object name and bucket name parameter.

Expand All @@ -57,7 +58,7 @@ def _fallback_object_url_to_object_name_and_bucket_name(
:type object_name_keyword_arg_name: str
:return: Decorator
"""
def _wrapper(func):
def _wrapper(func: T):

@functools.wraps(func)
def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT:
Expand Down Expand Up @@ -99,7 +100,7 @@ def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT:
)

return func(self, *args, **kwargs)
return _inner_wrapper
return cast(T, _inner_wrapper)
return _wrapper


Expand Down
Loading