Keep functions signatures in decorators (#9786)

This commit is contained in:
Kamil Breguła 2020-07-13 22:56:31 +02:00 коммит произвёл GitHub
Родитель 68925904e4
Коммит 553bb7af7c
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
16 изменённых файлов: 95 добавлений и 61 удалений

Просмотреть файл

@ -17,7 +17,7 @@
# under the License.
"""Default authentication backend - everything is allowed"""
from functools import wraps
from typing import Optional
from typing import Callable, Optional, TypeVar, cast
from airflow.typing_compat import Protocol
@ -40,10 +40,13 @@ def init_app(_):
"""Initializes authentication backend"""
def requires_authentication(function):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
@wraps(function)
def decorated(*args, **kwargs):
return function(*args, **kwargs)
return decorated
return cast(T, decorated)

Просмотреть файл

@ -17,7 +17,7 @@
# under the License.
"""Authentication backend that denies all requests"""
from functools import wraps
from typing import Optional
from typing import Callable, Optional, TypeVar, cast
from flask import Response
@ -30,7 +30,10 @@ def init_app(_):
"""Initializes authentication"""
def requires_authentication(function):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def requires_authentication(function: T):
"""Decorator for functions that require authentication"""
# noinspection PyUnusedLocal
@ -38,4 +41,4 @@ def requires_authentication(function):
def decorated(*args, **kwargs): # pylint: disable=unused-argument
return Response("Forbidden", 403)
return decorated
return cast(T, decorated)

Просмотреть файл

@ -44,6 +44,7 @@ import logging
import os
from functools import wraps
from socket import getfqdn
from typing import Callable, TypeVar, cast
import kerberos
# noinspection PyProtectedMember
@ -126,7 +127,10 @@ def _gssapi_authenticate(token):
kerberos.authGSSServerClean(state)
def requires_authentication(function):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def requires_authentication(function: T):
"""Decorator for functions that require authentication with Kerberos"""
@wraps(function)
def decorated(*args, **kwargs):
@ -147,4 +151,4 @@ def requires_authentication(function):
if return_code != kerberos.AUTH_GSS_CONTINUE:
return _forbidden()
return _unauthorized()
return decorated
return cast(T, decorated)

Просмотреть файл

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from functools import wraps
from typing import Callable, Dict
from typing import Callable, Dict, TypeVar, cast
from pendulum.parsing import ParserError
@ -59,7 +59,10 @@ def check_limit(value: int):
return value
def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def format_parameters(params_formatters: Dict[str, Callable[..., bool]]) -> Callable[[T], T]:
"""
Decorator factory that create decorator that convert parameters using given formatters.
@ -68,7 +71,7 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
:param params_formatters: Map of key name and formatter function
"""
def format_parameters_decorator(func):
def format_parameters_decorator(func: T):
@wraps(func)
def wrapped_function(*args, **kwargs):
for key, formatter in params_formatters.items():
@ -76,6 +79,6 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
kwargs[key] = formatter(kwargs[key])
return func(*args, **kwargs)
return wrapped_function
return cast(T, wrapped_function)
return format_parameters_decorator

Просмотреть файл

@ -21,7 +21,7 @@ Provides lineage support functions
import json
import logging
from functools import wraps
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, TypeVar, cast
import attr
import jinja2
@ -79,7 +79,10 @@ def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
return Metadata(type_name, source, data)
def apply_lineage(func):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def apply_lineage(func: T) -> T:
"""
Saves the lineage to XCom and if configured to do so sends it
to the backend.
@ -110,10 +113,10 @@ def apply_lineage(func):
return ret_val
return wrapper
return cast(T, wrapper)
def prepare_lineage(func):
def prepare_lineage(func: T) -> T:
"""
Prepares the lineage inlets and outlets. Inlets can be:
@ -172,4 +175,4 @@ def prepare_lineage(func):
self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
return func(self, context, *args, **kwargs)
return wrapper
return cast(T, wrapper)

Просмотреть файл

@ -46,7 +46,7 @@ depends_on = None
Base = declarative_base()
class DagRun(Base):
class DagRun(Base): # type: ignore
"""
DagRun describes an instance of a Dag. It can be created
by the scheduler (for regular runs) or by an external trigger

Просмотреть файл

@ -44,7 +44,7 @@ Base = declarative_base()
ID_LEN = 250
class TaskInstance(Base):
class TaskInstance(Base): # type: ignore
"""
Task instances store the state of a task instance. This table is the
authority and single source of truth around what tasks have run and the

Просмотреть файл

@ -44,7 +44,7 @@ BATCH_SIZE = 5000
ID_LEN = 250
class TaskInstance(Base): # noqa: D101
class TaskInstance(Base): # noqa: D101 # type: ignore
__tablename__ = "task_instance"
task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)

Просмотреть файл

@ -27,7 +27,7 @@ from inspect import signature
from itertools import islice
from tempfile import TemporaryDirectory
from textwrap import dedent
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast
import dill
@ -254,7 +254,14 @@ class _PythonFunctionalOperator(BaseOperator):
return return_value
def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = False, **kwargs):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def task(
python_callable: Optional[Callable] = None,
multiple_outputs: bool = False,
**kwargs
) -> Callable[[T], T]:
"""
Python operator decorator. Wraps a function into an Airflow operator.
Accepts kwargs for operator kwarg. Can be reused in a single DAG.
@ -268,7 +275,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa
:type multiple_outputs: bool
"""
def wrapper(f):
def wrapper(f: T):
"""
Python wrapper to generate PythonFunctionalOperator out of simple python functions.
Used for Airflow functional interface
@ -281,7 +288,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa
op = _PythonFunctionalOperator(python_callable=f, op_args=args, op_kwargs=f_kwargs,
multiple_outputs=multiple_outputs, **kwargs)
return XComArg(op)
return factory
return cast(T, factory)
if callable(python_callable):
return wrapper(python_callable)
elif python_callable is not None:

Просмотреть файл

@ -28,7 +28,7 @@ import shutil
from functools import wraps
from inspect import signature
from tempfile import NamedTemporaryFile
from typing import Optional
from typing import Callable, Optional, TypeVar, cast
from urllib.parse import urlparse
from botocore.exceptions import ClientError
@ -37,8 +37,10 @@ from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import chunks
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def provide_bucket_name(func):
def provide_bucket_name(func: T) -> T:
"""
Function decorator that provides a bucket name taken from the connection
in case no bucket name has been passed to the function.
@ -59,10 +61,10 @@ def provide_bucket_name(func):
return func(*bound_args.args, **bound_args.kwargs)
return wrapper
return cast(T, wrapper)
def unify_bucket_name_and_key(func):
def unify_bucket_name_and_key(func: T) -> T:
"""
Function decorator that unifies bucket name and key taken from the key
in case no bucket name and at least a key has been passed to the function.
@ -88,7 +90,7 @@ def unify_bucket_name_and_key(func):
return func(*bound_args.args, **bound_args.kwargs)
return wrapper
return cast(T, wrapper)
class S3Hook(AwsBaseHook):

Просмотреть файл

@ -29,7 +29,7 @@ import uuid
import warnings
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
from googleapiclient.discovery import build
@ -47,12 +47,12 @@ JOB_ID_PATTERN = re.compile(
r'Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]'
)
RT = TypeVar('RT') # pylint: disable=invalid-name
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def _fallback_variable_parameter(parameter_name, variable_key_name):
def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]:
def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]:
def _wrapper(func: T) -> T:
"""
Decorator that provides fallback for location from `region` key in `variables` parameters.
@ -60,7 +60,7 @@ def _fallback_variable_parameter(parameter_name, variable_key_name):
:return: result of the function call
"""
@functools.wraps(func)
def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
def inner_wrapper(self: "DataflowHook", *args, **kwargs):
if args:
raise AirflowException(
"You must use keyword arguments in this methods rather than positional")
@ -81,7 +81,7 @@ def _fallback_variable_parameter(parameter_name, variable_key_name):
kwargs['variables'] = copy_variables
return func(self, *args, **kwargs)
return inner_wrapper
return cast(T, inner_wrapper)
return _wrapper

Просмотреть файл

@ -28,7 +28,7 @@ from contextlib import contextmanager
from io import BytesIO
from os import path
from tempfile import NamedTemporaryFile
from typing import Optional, Set, Tuple, TypeVar, Union
from typing import Callable, Optional, Set, Tuple, TypeVar, Union, cast
from urllib.parse import urlparse
from google.api_core.exceptions import NotFound
@ -39,13 +39,14 @@ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.version import version
RT = TypeVar('RT') # pylint: disable=invalid-name
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def _fallback_object_url_to_object_name_and_bucket_name(
object_url_keyword_arg_name='object_url',
bucket_name_keyword_arg_name='bucket_name',
object_name_keyword_arg_name='object_name',
):
) -> Callable[[T], T]:
"""
Decorator factory that convert object URL parameter to object name and bucket name parameter.
@ -57,7 +58,7 @@ def _fallback_object_url_to_object_name_and_bucket_name(
:type object_name_keyword_arg_name: str
:return: Decorator
"""
def _wrapper(func):
def _wrapper(func: T):
@functools.wraps(func)
def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT:
@ -99,7 +100,7 @@ def _fallback_object_url_to_object_name_and_bucket_name(
)
return func(self, *args, **kwargs)
return _inner_wrapper
return cast(T, _inner_wrapper)
return _wrapper

Просмотреть файл

@ -26,7 +26,7 @@ import os
import tempfile
from contextlib import contextmanager
from subprocess import check_output
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, cast
import google.auth
import google.auth.credentials
@ -118,6 +118,7 @@ class retry_if_operation_in_progress(tenacity.retry_if_exception): # pylint: di
super().__init__(is_operation_in_progress_exception)
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
RT = TypeVar('RT') # pylint: disable=invalid-name
@ -309,13 +310,13 @@ class GoogleBaseHook(BaseHook):
return decorator
@staticmethod
def operation_in_progress_retry(*args, **kwargs) -> Callable:
def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]:
"""
A decorator that provides a mechanism to repeat requests in response to
operation in progress (HTTP 409)
limit.
"""
def decorator(fun: Callable):
def decorator(fun: T):
default_kwargs = {
'wait': tenacity.wait_exponential(multiplier=1, max=300),
'retry': retry_if_operation_in_progress(),
@ -323,9 +324,9 @@ class GoogleBaseHook(BaseHook):
'after': tenacity.after_log(log, logging.DEBUG),
}
default_kwargs.update(**kwargs)
return tenacity.retry(
return cast(T, tenacity.retry(
*args, **default_kwargs
)(fun)
)(fun))
return decorator
@staticmethod
@ -357,7 +358,7 @@ class GoogleBaseHook(BaseHook):
return inner_wrapper
@staticmethod
def provide_gcp_credential_file(func: Callable[..., RT]) -> Callable[..., RT]:
def provide_gcp_credential_file(func: T) -> T:
"""
Function decorator that provides a GCP credentials for application supporting Application
Default Credentials (ADC) strategy.
@ -367,10 +368,10 @@ class GoogleBaseHook(BaseHook):
makes it easier to use multiple connection in one function.
"""
@functools.wraps(func)
def wrapper(self: GoogleBaseHook, *args, **kwargs) -> RT:
def wrapper(self: GoogleBaseHook, *args, **kwargs):
with self.provide_gcp_credential_file_as_context():
return func(self, *args, **kwargs)
return wrapper
return cast(T, wrapper)
@contextmanager
def provide_gcp_credential_file_as_context(self):

Просмотреть файл

@ -16,13 +16,12 @@
# specific language governing permissions and limitations
# under the License.
import logging
import socket
import string
import textwrap
from functools import wraps
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException, InvalidStatsNameException
@ -94,7 +93,10 @@ def get_current_handler_stat_name_func() -> Callable[[str], str]:
return conf.getimport('scheduler', 'stat_name_handler') or stat_name_default_handler
def validate_stat(fn):
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def validate_stat(fn: T) -> T:
"""Check if stat name contains invalid characters.
Log and not emit stats if name is invalid
"""
@ -108,7 +110,7 @@ def validate_stat(fn):
log.error('Invalid stat name: %s.', stat, exc_info=True)
return
return wrapper
return cast(T, wrapper)
class AllowListValidator:

Просмотреть файл

@ -32,7 +32,7 @@ import threading
import traceback
from argparse import Namespace
from datetime import datetime
from typing import Optional
from typing import Callable, Optional, TypeVar, cast
from airflow import settings
from airflow.exceptions import AirflowException
@ -41,8 +41,10 @@ from airflow.utils import cli_action_loggers
from airflow.utils.platform import is_terminal_support_colors
from airflow.utils.session import provide_session
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def action_logging(f):
def action_logging(f: T) -> T:
"""
Decorates function to execute function at the same time submitting action_logging
but in CLI context. It will call action logger callbacks twice,
@ -89,7 +91,7 @@ def action_logging(f):
metrics['end_datetime'] = datetime.utcnow()
cli_action_loggers.on_post_execution(**metrics)
return wrapper
return cast(T, wrapper)
def _build_metrics(func_name, namespace):

Просмотреть файл

@ -19,6 +19,7 @@
import functools
import gzip
from io import BytesIO as IO
from typing import Callable, TypeVar, cast
import pendulum
from flask import after_this_request, flash, g, redirect, request, url_for
@ -26,8 +27,10 @@ from flask import after_this_request, flash, g, redirect, request, url_for
from airflow.models import Log
from airflow.utils.session import create_session
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
def action_logging(f):
def action_logging(f: T) -> T:
"""
Decorator to log user actions
"""
@ -56,10 +59,10 @@ def action_logging(f):
return f(*args, **kwargs)
return wrapper
return cast(T, wrapper)
def gzipped(f):
def gzipped(f: T) -> T:
"""
Decorator to make a view compressed
"""
@ -92,14 +95,14 @@ def gzipped(f):
return f(*args, **kwargs)
return view_func
return cast(T, view_func)
def has_dag_access(**dag_kwargs):
def has_dag_access(**dag_kwargs) -> Callable[[T], T]:
"""
Decorator to check whether the user has read / write permission on the dag.
"""
def decorator(f):
def decorator(f: T):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
has_access = self.appbuilder.sm.has_access
@ -124,5 +127,5 @@ def has_dag_access(**dag_kwargs):
flash("Access is Denied", "danger")
return redirect(url_for(self.appbuilder.sm.auth_view.
__class__.__name__ + ".login"))
return wrapper
return cast(T, wrapper)
return decorator