Keep functions signatures in decorators (#9786)
This commit is contained in:
Родитель
68925904e4
Коммит
553bb7af7c
|
@ -17,7 +17,7 @@
|
|||
# under the License.
|
||||
"""Default authentication backend - everything is allowed"""
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, TypeVar, cast
|
||||
|
||||
from airflow.typing_compat import Protocol
|
||||
|
||||
|
@ -40,10 +40,13 @@ def init_app(_):
|
|||
"""Initializes authentication backend"""
|
||||
|
||||
|
||||
def requires_authentication(function):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def requires_authentication(function: T):
|
||||
"""Decorator for functions that require authentication"""
|
||||
@wraps(function)
|
||||
def decorated(*args, **kwargs):
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
return cast(T, decorated)
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
# under the License.
|
||||
"""Authentication backend that denies all requests"""
|
||||
from functools import wraps
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, TypeVar, cast
|
||||
|
||||
from flask import Response
|
||||
|
||||
|
@ -30,7 +30,10 @@ def init_app(_):
|
|||
"""Initializes authentication"""
|
||||
|
||||
|
||||
def requires_authentication(function):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def requires_authentication(function: T):
|
||||
"""Decorator for functions that require authentication"""
|
||||
|
||||
# noinspection PyUnusedLocal
|
||||
|
@ -38,4 +41,4 @@ def requires_authentication(function):
|
|||
def decorated(*args, **kwargs): # pylint: disable=unused-argument
|
||||
return Response("Forbidden", 403)
|
||||
|
||||
return decorated
|
||||
return cast(T, decorated)
|
||||
|
|
|
@ -44,6 +44,7 @@ import logging
|
|||
import os
|
||||
from functools import wraps
|
||||
from socket import getfqdn
|
||||
from typing import Callable, TypeVar, cast
|
||||
|
||||
import kerberos
|
||||
# noinspection PyProtectedMember
|
||||
|
@ -126,7 +127,10 @@ def _gssapi_authenticate(token):
|
|||
kerberos.authGSSServerClean(state)
|
||||
|
||||
|
||||
def requires_authentication(function):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def requires_authentication(function: T):
|
||||
"""Decorator for functions that require authentication with Kerberos"""
|
||||
@wraps(function)
|
||||
def decorated(*args, **kwargs):
|
||||
|
@ -147,4 +151,4 @@ def requires_authentication(function):
|
|||
if return_code != kerberos.AUTH_GSS_CONTINUE:
|
||||
return _forbidden()
|
||||
return _unauthorized()
|
||||
return decorated
|
||||
return cast(T, decorated)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict
|
||||
from typing import Callable, Dict, TypeVar, cast
|
||||
|
||||
from pendulum.parsing import ParserError
|
||||
|
||||
|
@ -59,7 +59,10 @@ def check_limit(value: int):
|
|||
return value
|
||||
|
||||
|
||||
def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def format_parameters(params_formatters: Dict[str, Callable[..., bool]]) -> Callable[[T], T]:
|
||||
"""
|
||||
Decorator factory that create decorator that convert parameters using given formatters.
|
||||
|
||||
|
@ -68,7 +71,7 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
|
|||
:param params_formatters: Map of key name and formatter function
|
||||
"""
|
||||
|
||||
def format_parameters_decorator(func):
|
||||
def format_parameters_decorator(func: T):
|
||||
@wraps(func)
|
||||
def wrapped_function(*args, **kwargs):
|
||||
for key, formatter in params_formatters.items():
|
||||
|
@ -76,6 +79,6 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]):
|
|||
kwargs[key] = formatter(kwargs[key])
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped_function
|
||||
return cast(T, wrapped_function)
|
||||
|
||||
return format_parameters_decorator
|
||||
|
|
|
@ -21,7 +21,7 @@ Provides lineage support functions
|
|||
import json
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional, TypeVar, cast
|
||||
|
||||
import attr
|
||||
import jinja2
|
||||
|
@ -79,7 +79,10 @@ def _to_dataset(obj: Any, source: str) -> Optional[Metadata]:
|
|||
return Metadata(type_name, source, data)
|
||||
|
||||
|
||||
def apply_lineage(func):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def apply_lineage(func: T) -> T:
|
||||
"""
|
||||
Saves the lineage to XCom and if configured to do so sends it
|
||||
to the backend.
|
||||
|
@ -110,10 +113,10 @@ def apply_lineage(func):
|
|||
|
||||
return ret_val
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
def prepare_lineage(func):
|
||||
def prepare_lineage(func: T) -> T:
|
||||
"""
|
||||
Prepares the lineage inlets and outlets. Inlets can be:
|
||||
|
||||
|
@ -172,4 +175,4 @@ def prepare_lineage(func):
|
|||
self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)
|
||||
return func(self, context, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
|
|
@ -46,7 +46,7 @@ depends_on = None
|
|||
Base = declarative_base()
|
||||
|
||||
|
||||
class DagRun(Base):
|
||||
class DagRun(Base): # type: ignore
|
||||
"""
|
||||
DagRun describes an instance of a Dag. It can be created
|
||||
by the scheduler (for regular runs) or by an external trigger
|
||||
|
|
|
@ -44,7 +44,7 @@ Base = declarative_base()
|
|||
ID_LEN = 250
|
||||
|
||||
|
||||
class TaskInstance(Base):
|
||||
class TaskInstance(Base): # type: ignore
|
||||
"""
|
||||
Task instances store the state of a task instance. This table is the
|
||||
authority and single source of truth around what tasks have run and the
|
||||
|
|
|
@ -44,7 +44,7 @@ BATCH_SIZE = 5000
|
|||
ID_LEN = 250
|
||||
|
||||
|
||||
class TaskInstance(Base): # noqa: D101
|
||||
class TaskInstance(Base): # noqa: D101 # type: ignore
|
||||
__tablename__ = "task_instance"
|
||||
|
||||
task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
|
||||
|
|
|
@ -27,7 +27,7 @@ from inspect import signature
|
|||
from itertools import islice
|
||||
from tempfile import TemporaryDirectory
|
||||
from textwrap import dedent
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast
|
||||
|
||||
import dill
|
||||
|
||||
|
@ -254,7 +254,14 @@ class _PythonFunctionalOperator(BaseOperator):
|
|||
return return_value
|
||||
|
||||
|
||||
def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = False, **kwargs):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def task(
|
||||
python_callable: Optional[Callable] = None,
|
||||
multiple_outputs: bool = False,
|
||||
**kwargs
|
||||
) -> Callable[[T], T]:
|
||||
"""
|
||||
Python operator decorator. Wraps a function into an Airflow operator.
|
||||
Accepts kwargs for operator kwarg. Can be reused in a single DAG.
|
||||
|
@ -268,7 +275,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa
|
|||
:type multiple_outputs: bool
|
||||
|
||||
"""
|
||||
def wrapper(f):
|
||||
def wrapper(f: T):
|
||||
"""
|
||||
Python wrapper to generate PythonFunctionalOperator out of simple python functions.
|
||||
Used for Airflow functional interface
|
||||
|
@ -281,7 +288,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa
|
|||
op = _PythonFunctionalOperator(python_callable=f, op_args=args, op_kwargs=f_kwargs,
|
||||
multiple_outputs=multiple_outputs, **kwargs)
|
||||
return XComArg(op)
|
||||
return factory
|
||||
return cast(T, factory)
|
||||
if callable(python_callable):
|
||||
return wrapper(python_callable)
|
||||
elif python_callable is not None:
|
||||
|
|
|
@ -28,7 +28,7 @@ import shutil
|
|||
from functools import wraps
|
||||
from inspect import signature
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, TypeVar, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from botocore.exceptions import ClientError
|
||||
|
@ -37,8 +37,10 @@ from airflow.exceptions import AirflowException
|
|||
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
|
||||
from airflow.utils.helpers import chunks
|
||||
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
def provide_bucket_name(func):
|
||||
|
||||
def provide_bucket_name(func: T) -> T:
|
||||
"""
|
||||
Function decorator that provides a bucket name taken from the connection
|
||||
in case no bucket name has been passed to the function.
|
||||
|
@ -59,10 +61,10 @@ def provide_bucket_name(func):
|
|||
|
||||
return func(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
def unify_bucket_name_and_key(func):
|
||||
def unify_bucket_name_and_key(func: T) -> T:
|
||||
"""
|
||||
Function decorator that unifies bucket name and key taken from the key
|
||||
in case no bucket name and at least a key has been passed to the function.
|
||||
|
@ -88,7 +90,7 @@ def unify_bucket_name_and_key(func):
|
|||
|
||||
return func(*bound_args.args, **bound_args.kwargs)
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
class S3Hook(AwsBaseHook):
|
||||
|
|
|
@ -29,7 +29,7 @@ import uuid
|
|||
import warnings
|
||||
from copy import deepcopy
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
|
||||
|
@ -47,12 +47,12 @@ JOB_ID_PATTERN = re.compile(
|
|||
r'Submitted job: (?P<job_id_java>.*)|Created job with id: \[(?P<job_id_python>.*)\]'
|
||||
)
|
||||
|
||||
RT = TypeVar('RT') # pylint: disable=invalid-name
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _fallback_variable_parameter(parameter_name, variable_key_name):
|
||||
def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]:
|
||||
|
||||
def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
def _wrapper(func: T) -> T:
|
||||
"""
|
||||
Decorator that provides fallback for location from `region` key in `variables` parameters.
|
||||
|
||||
|
@ -60,7 +60,7 @@ def _fallback_variable_parameter(parameter_name, variable_key_name):
|
|||
:return: result of the function call
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT:
|
||||
def inner_wrapper(self: "DataflowHook", *args, **kwargs):
|
||||
if args:
|
||||
raise AirflowException(
|
||||
"You must use keyword arguments in this methods rather than positional")
|
||||
|
@ -81,7 +81,7 @@ def _fallback_variable_parameter(parameter_name, variable_key_name):
|
|||
kwargs['variables'] = copy_variables
|
||||
|
||||
return func(self, *args, **kwargs)
|
||||
return inner_wrapper
|
||||
return cast(T, inner_wrapper)
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ from contextlib import contextmanager
|
|||
from io import BytesIO
|
||||
from os import path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, Set, Tuple, TypeVar, Union
|
||||
from typing import Callable, Optional, Set, Tuple, TypeVar, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from google.api_core.exceptions import NotFound
|
||||
|
@ -39,13 +39,14 @@ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
|
|||
from airflow.version import version
|
||||
|
||||
RT = TypeVar('RT') # pylint: disable=invalid-name
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _fallback_object_url_to_object_name_and_bucket_name(
|
||||
object_url_keyword_arg_name='object_url',
|
||||
bucket_name_keyword_arg_name='bucket_name',
|
||||
object_name_keyword_arg_name='object_name',
|
||||
):
|
||||
) -> Callable[[T], T]:
|
||||
"""
|
||||
Decorator factory that convert object URL parameter to object name and bucket name parameter.
|
||||
|
||||
|
@ -57,7 +58,7 @@ def _fallback_object_url_to_object_name_and_bucket_name(
|
|||
:type object_name_keyword_arg_name: str
|
||||
:return: Decorator
|
||||
"""
|
||||
def _wrapper(func):
|
||||
def _wrapper(func: T):
|
||||
|
||||
@functools.wraps(func)
|
||||
def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT:
|
||||
|
@ -99,7 +100,7 @@ def _fallback_object_url_to_object_name_and_bucket_name(
|
|||
)
|
||||
|
||||
return func(self, *args, **kwargs)
|
||||
return _inner_wrapper
|
||||
return cast(T, _inner_wrapper)
|
||||
return _wrapper
|
||||
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import os
|
|||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
from subprocess import check_output
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar
|
||||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, cast
|
||||
|
||||
import google.auth
|
||||
import google.auth.credentials
|
||||
|
@ -118,6 +118,7 @@ class retry_if_operation_in_progress(tenacity.retry_if_exception): # pylint: di
|
|||
super().__init__(is_operation_in_progress_exception)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
RT = TypeVar('RT') # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
@ -309,13 +310,13 @@ class GoogleBaseHook(BaseHook):
|
|||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def operation_in_progress_retry(*args, **kwargs) -> Callable:
|
||||
def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]:
|
||||
"""
|
||||
A decorator that provides a mechanism to repeat requests in response to
|
||||
operation in progress (HTTP 409)
|
||||
limit.
|
||||
"""
|
||||
def decorator(fun: Callable):
|
||||
def decorator(fun: T):
|
||||
default_kwargs = {
|
||||
'wait': tenacity.wait_exponential(multiplier=1, max=300),
|
||||
'retry': retry_if_operation_in_progress(),
|
||||
|
@ -323,9 +324,9 @@ class GoogleBaseHook(BaseHook):
|
|||
'after': tenacity.after_log(log, logging.DEBUG),
|
||||
}
|
||||
default_kwargs.update(**kwargs)
|
||||
return tenacity.retry(
|
||||
return cast(T, tenacity.retry(
|
||||
*args, **default_kwargs
|
||||
)(fun)
|
||||
)(fun))
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
|
@ -357,7 +358,7 @@ class GoogleBaseHook(BaseHook):
|
|||
return inner_wrapper
|
||||
|
||||
@staticmethod
|
||||
def provide_gcp_credential_file(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
def provide_gcp_credential_file(func: T) -> T:
|
||||
"""
|
||||
Function decorator that provides a GCP credentials for application supporting Application
|
||||
Default Credentials (ADC) strategy.
|
||||
|
@ -367,10 +368,10 @@ class GoogleBaseHook(BaseHook):
|
|||
makes it easier to use multiple connection in one function.
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def wrapper(self: GoogleBaseHook, *args, **kwargs) -> RT:
|
||||
def wrapper(self: GoogleBaseHook, *args, **kwargs):
|
||||
with self.provide_gcp_credential_file_as_context():
|
||||
return func(self, *args, **kwargs)
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
@contextmanager
|
||||
def provide_gcp_credential_file_as_context(self):
|
||||
|
|
|
@ -16,13 +16,12 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import string
|
||||
import textwrap
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast
|
||||
|
||||
from airflow.configuration import conf
|
||||
from airflow.exceptions import AirflowConfigException, InvalidStatsNameException
|
||||
|
@ -94,7 +93,10 @@ def get_current_handler_stat_name_func() -> Callable[[str], str]:
|
|||
return conf.getimport('scheduler', 'stat_name_handler') or stat_name_default_handler
|
||||
|
||||
|
||||
def validate_stat(fn):
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def validate_stat(fn: T) -> T:
|
||||
"""Check if stat name contains invalid characters.
|
||||
Log and not emit stats if name is invalid
|
||||
"""
|
||||
|
@ -108,7 +110,7 @@ def validate_stat(fn):
|
|||
log.error('Invalid stat name: %s.', stat, exc_info=True)
|
||||
return
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
class AllowListValidator:
|
||||
|
|
|
@ -32,7 +32,7 @@ import threading
|
|||
import traceback
|
||||
from argparse import Namespace
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional, TypeVar, cast
|
||||
|
||||
from airflow import settings
|
||||
from airflow.exceptions import AirflowException
|
||||
|
@ -41,8 +41,10 @@ from airflow.utils import cli_action_loggers
|
|||
from airflow.utils.platform import is_terminal_support_colors
|
||||
from airflow.utils.session import provide_session
|
||||
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
def action_logging(f):
|
||||
|
||||
def action_logging(f: T) -> T:
|
||||
"""
|
||||
Decorates function to execute function at the same time submitting action_logging
|
||||
but in CLI context. It will call action logger callbacks twice,
|
||||
|
@ -89,7 +91,7 @@ def action_logging(f):
|
|||
metrics['end_datetime'] = datetime.utcnow()
|
||||
cli_action_loggers.on_post_execution(**metrics)
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
def _build_metrics(func_name, namespace):
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
import functools
|
||||
import gzip
|
||||
from io import BytesIO as IO
|
||||
from typing import Callable, TypeVar, cast
|
||||
|
||||
import pendulum
|
||||
from flask import after_this_request, flash, g, redirect, request, url_for
|
||||
|
@ -26,8 +27,10 @@ from flask import after_this_request, flash, g, redirect, request, url_for
|
|||
from airflow.models import Log
|
||||
from airflow.utils.session import create_session
|
||||
|
||||
T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name
|
||||
|
||||
def action_logging(f):
|
||||
|
||||
def action_logging(f: T) -> T:
|
||||
"""
|
||||
Decorator to log user actions
|
||||
"""
|
||||
|
@ -56,10 +59,10 @@ def action_logging(f):
|
|||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
|
||||
|
||||
def gzipped(f):
|
||||
def gzipped(f: T) -> T:
|
||||
"""
|
||||
Decorator to make a view compressed
|
||||
"""
|
||||
|
@ -92,14 +95,14 @@ def gzipped(f):
|
|||
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return view_func
|
||||
return cast(T, view_func)
|
||||
|
||||
|
||||
def has_dag_access(**dag_kwargs):
|
||||
def has_dag_access(**dag_kwargs) -> Callable[[T], T]:
|
||||
"""
|
||||
Decorator to check whether the user has read / write permission on the dag.
|
||||
"""
|
||||
def decorator(f):
|
||||
def decorator(f: T):
|
||||
@functools.wraps(f)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
has_access = self.appbuilder.sm.has_access
|
||||
|
@ -124,5 +127,5 @@ def has_dag_access(**dag_kwargs):
|
|||
flash("Access is Denied", "danger")
|
||||
return redirect(url_for(self.appbuilder.sm.auth_view.
|
||||
__class__.__name__ + ".login"))
|
||||
return wrapper
|
||||
return cast(T, wrapper)
|
||||
return decorator
|
||||
|
|
Загрузка…
Ссылка в новой задаче