diff --git a/airflow/api/auth/backend/default.py b/airflow/api/auth/backend/default.py index 883dd2f457..d444c80579 100644 --- a/airflow/api/auth/backend/default.py +++ b/airflow/api/auth/backend/default.py @@ -17,7 +17,7 @@ # under the License. """Default authentication backend - everything is allowed""" from functools import wraps -from typing import Optional +from typing import Callable, Optional, TypeVar, cast from airflow.typing_compat import Protocol @@ -40,10 +40,13 @@ def init_app(_): """Initializes authentication backend""" -def requires_authentication(function): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def requires_authentication(function: T): """Decorator for functions that require authentication""" @wraps(function) def decorated(*args, **kwargs): return function(*args, **kwargs) - return decorated + return cast(T, decorated) diff --git a/airflow/api/auth/backend/deny_all.py b/airflow/api/auth/backend/deny_all.py index a609addf2c..2971a5aff7 100644 --- a/airflow/api/auth/backend/deny_all.py +++ b/airflow/api/auth/backend/deny_all.py @@ -17,7 +17,7 @@ # under the License. """Authentication backend that denies all requests""" from functools import wraps -from typing import Optional +from typing import Callable, Optional, TypeVar, cast from flask import Response @@ -30,7 +30,10 @@ def init_app(_): """Initializes authentication""" -def requires_authentication(function): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def requires_authentication(function: T): """Decorator for functions that require authentication""" # noinspection PyUnusedLocal @@ -38,4 +41,4 @@ def requires_authentication(function): def decorated(*args, **kwargs): # pylint: disable=unused-argument return Response("Forbidden", 403) - return decorated + return cast(T, decorated) diff --git a/airflow/api/auth/backend/kerberos_auth.py b/airflow/api/auth/backend/kerberos_auth.py index 6c6d431eea..9780e3ae10 100644 --- a/airflow/api/auth/backend/kerberos_auth.py +++ b/airflow/api/auth/backend/kerberos_auth.py @@ -44,6 +44,7 @@ import logging import os from functools import wraps from socket import getfqdn +from typing import Callable, TypeVar, cast import kerberos # noinspection PyProtectedMember @@ -126,7 +127,10 @@ def _gssapi_authenticate(token): kerberos.authGSSServerClean(state) -def requires_authentication(function): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def requires_authentication(function: T): """Decorator for functions that require authentication with Kerberos""" @wraps(function) def decorated(*args, **kwargs): @@ -147,4 +151,4 @@ def requires_authentication(function): if return_code != kerberos.AUTH_GSS_CONTINUE: return _forbidden() return _unauthorized() - return decorated + return cast(T, decorated) diff --git a/airflow/api_connexion/parameters.py b/airflow/api_connexion/parameters.py index 6ee9be58a9..6c99f74387 100644 --- a/airflow/api_connexion/parameters.py +++ b/airflow/api_connexion/parameters.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from functools import wraps -from typing import Callable, Dict +from typing import Callable, Dict, TypeVar, cast from pendulum.parsing import ParserError @@ -59,7 +59,10 @@ def check_limit(value: int): return value -def format_parameters(params_formatters: Dict[str, Callable[..., bool]]): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def format_parameters(params_formatters: Dict[str, Callable[..., bool]]) -> Callable[[T], T]: """ Decorator factory that create decorator that convert parameters using given formatters. @@ -68,7 +71,7 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]): :param params_formatters: Map of key name and formatter function """ - def format_parameters_decorator(func): + def format_parameters_decorator(func: T): @wraps(func) def wrapped_function(*args, **kwargs): for key, formatter in params_formatters.items(): @@ -76,6 +79,6 @@ def format_parameters(params_formatters: Dict[str, Callable[..., bool]]): kwargs[key] = formatter(kwargs[key]) return func(*args, **kwargs) - return wrapped_function + return cast(T, wrapped_function) return format_parameters_decorator diff --git a/airflow/lineage/__init__.py b/airflow/lineage/__init__.py index f3c043bb0f..036e793450 100644 --- a/airflow/lineage/__init__.py +++ b/airflow/lineage/__init__.py @@ -21,7 +21,7 @@ Provides lineage support functions import json import logging from functools import wraps -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional, TypeVar, cast import attr import jinja2 @@ -79,7 +79,10 @@ def _to_dataset(obj: Any, source: str) -> Optional[Metadata]: return Metadata(type_name, source, data) -def apply_lineage(func): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def apply_lineage(func: T) -> T: """ Saves the lineage to XCom and if configured to do so sends it to the backend. @@ -110,10 +113,10 @@ def apply_lineage(func): return ret_val - return wrapper + return cast(T, wrapper) -def prepare_lineage(func): +def prepare_lineage(func: T) -> T: """ Prepares the lineage inlets and outlets. Inlets can be: @@ -172,4 +175,4 @@ def prepare_lineage(func): self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets) return func(self, context, *args, **kwargs) - return wrapper + return cast(T, wrapper) diff --git a/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py b/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py index cc3de10ac3..5ea22d0a11 100644 --- a/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py +++ b/airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py @@ -46,7 +46,7 @@ depends_on = None Base = declarative_base() -class DagRun(Base): +class DagRun(Base): # type: ignore """ DagRun describes an instance of a Dag. It can be created by the scheduler (for regular runs) or by an external trigger diff --git a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py index 31bd0359c8..1504682baf 100644 --- a/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py +++ b/airflow/migrations/versions/6e96a59344a4_make_taskinstance_pool_not_nullable.py @@ -44,7 +44,7 @@ Base = declarative_base() ID_LEN = 250 -class TaskInstance(Base): +class TaskInstance(Base): # type: ignore """ Task instances store the state of a task instance. This table is the authority and single source of truth around what tasks have run and the diff --git a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py index 83c15a9e7a..46f9178781 100644 --- a/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py +++ b/airflow/migrations/versions/cc1e65623dc7_add_max_tries_column_to_task_instance.py @@ -44,7 +44,7 @@ BATCH_SIZE = 5000 ID_LEN = 250 -class TaskInstance(Base): # noqa: D101 +class TaskInstance(Base): # noqa: D101 # type: ignore __tablename__ = "task_instance" task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True) diff --git a/airflow/operators/python.py b/airflow/operators/python.py index 5bbc715ec1..d33725c550 100644 --- a/airflow/operators/python.py +++ b/airflow/operators/python.py @@ -27,7 +27,7 @@ from inspect import signature from itertools import islice from tempfile import TemporaryDirectory from textwrap import dedent -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, cast import dill @@ -254,7 +254,14 @@ class _PythonFunctionalOperator(BaseOperator): return return_value -def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = False, **kwargs): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def task( + python_callable: Optional[Callable] = None, + multiple_outputs: bool = False, + **kwargs +) -> Callable[[T], T]: """ Python operator decorator. Wraps a function into an Airflow operator. Accepts kwargs for operator kwarg. Can be reused in a single DAG. @@ -268,7 +275,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa :type multiple_outputs: bool """ - def wrapper(f): + def wrapper(f: T): """ Python wrapper to generate PythonFunctionalOperator out of simple python functions. Used for Airflow functional interface @@ -281,7 +288,7 @@ def task(python_callable: Optional[Callable] = None, multiple_outputs: bool = Fa op = _PythonFunctionalOperator(python_callable=f, op_args=args, op_kwargs=f_kwargs, multiple_outputs=multiple_outputs, **kwargs) return XComArg(op) - return factory + return cast(T, factory) if callable(python_callable): return wrapper(python_callable) elif python_callable is not None: diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py index d2d6be629d..3e4b88b153 100644 --- a/airflow/providers/amazon/aws/hooks/s3.py +++ b/airflow/providers/amazon/aws/hooks/s3.py @@ -28,7 +28,7 @@ import shutil from functools import wraps from inspect import signature from tempfile import NamedTemporaryFile -from typing import Optional +from typing import Callable, Optional, TypeVar, cast from urllib.parse import urlparse from botocore.exceptions import ClientError @@ -37,8 +37,10 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.utils.helpers import chunks +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name -def provide_bucket_name(func): + +def provide_bucket_name(func: T) -> T: """ Function decorator that provides a bucket name taken from the connection in case no bucket name has been passed to the function. @@ -59,10 +61,10 @@ def provide_bucket_name(func): return func(*bound_args.args, **bound_args.kwargs) - return wrapper + return cast(T, wrapper) -def unify_bucket_name_and_key(func): +def unify_bucket_name_and_key(func: T) -> T: """ Function decorator that unifies bucket name and key taken from the key in case no bucket name and at least a key has been passed to the function. @@ -88,7 +90,7 @@ def unify_bucket_name_and_key(func): return func(*bound_args.args, **bound_args.kwargs) - return wrapper + return cast(T, wrapper) class S3Hook(AwsBaseHook): diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index bf5b683b5a..ee3b3f6258 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -29,7 +29,7 @@ import uuid import warnings from copy import deepcopy from tempfile import TemporaryDirectory -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, List, Optional, TypeVar, cast from googleapiclient.discovery import build @@ -47,12 +47,12 @@ JOB_ID_PATTERN = re.compile( r'Submitted job: (?P.*)|Created job with id: \[(?P.*)\]' ) -RT = TypeVar('RT') # pylint: disable=invalid-name +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name -def _fallback_variable_parameter(parameter_name, variable_key_name): +def _fallback_variable_parameter(parameter_name: str, variable_key_name: str) -> Callable[[T], T]: - def _wrapper(func: Callable[..., RT]) -> Callable[..., RT]: + def _wrapper(func: T) -> T: """ Decorator that provides fallback for location from `region` key in `variables` parameters. @@ -60,7 +60,7 @@ def _fallback_variable_parameter(parameter_name, variable_key_name): :return: result of the function call """ @functools.wraps(func) - def inner_wrapper(self: "DataflowHook", *args, **kwargs) -> RT: + def inner_wrapper(self: "DataflowHook", *args, **kwargs): if args: raise AirflowException( "You must use keyword arguments in this methods rather than positional") @@ -81,7 +81,7 @@ def _fallback_variable_parameter(parameter_name, variable_key_name): kwargs['variables'] = copy_variables return func(self, *args, **kwargs) - return inner_wrapper + return cast(T, inner_wrapper) return _wrapper diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index 11a1954255..94523ff961 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -28,7 +28,7 @@ from contextlib import contextmanager from io import BytesIO from os import path from tempfile import NamedTemporaryFile -from typing import Optional, Set, Tuple, TypeVar, Union +from typing import Callable, Optional, Set, Tuple, TypeVar, Union, cast from urllib.parse import urlparse from google.api_core.exceptions import NotFound @@ -39,13 +39,14 @@ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook from airflow.version import version RT = TypeVar('RT') # pylint: disable=invalid-name +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name def _fallback_object_url_to_object_name_and_bucket_name( object_url_keyword_arg_name='object_url', bucket_name_keyword_arg_name='bucket_name', object_name_keyword_arg_name='object_name', -): +) -> Callable[[T], T]: """ Decorator factory that convert object URL parameter to object name and bucket name parameter. @@ -57,7 +58,7 @@ def _fallback_object_url_to_object_name_and_bucket_name( :type object_name_keyword_arg_name: str :return: Decorator """ - def _wrapper(func): + def _wrapper(func: T): @functools.wraps(func) def _inner_wrapper(self: "GCSHook", * args, **kwargs) -> RT: @@ -99,7 +100,7 @@ def _fallback_object_url_to_object_name_and_bucket_name( ) return func(self, *args, **kwargs) - return _inner_wrapper + return cast(T, _inner_wrapper) return _wrapper diff --git a/airflow/providers/google/common/hooks/base_google.py b/airflow/providers/google/common/hooks/base_google.py index 3cc3eb0234..4e942a7541 100644 --- a/airflow/providers/google/common/hooks/base_google.py +++ b/airflow/providers/google/common/hooks/base_google.py @@ -26,7 +26,7 @@ import os import tempfile from contextlib import contextmanager from subprocess import check_output -from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, cast import google.auth import google.auth.credentials @@ -118,6 +118,7 @@ class retry_if_operation_in_progress(tenacity.retry_if_exception): # pylint: di super().__init__(is_operation_in_progress_exception) +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name RT = TypeVar('RT') # pylint: disable=invalid-name @@ -309,13 +310,13 @@ class GoogleBaseHook(BaseHook): return decorator @staticmethod - def operation_in_progress_retry(*args, **kwargs) -> Callable: + def operation_in_progress_retry(*args, **kwargs) -> Callable[[T], T]: """ A decorator that provides a mechanism to repeat requests in response to operation in progress (HTTP 409) limit. """ - def decorator(fun: Callable): + def decorator(fun: T): default_kwargs = { 'wait': tenacity.wait_exponential(multiplier=1, max=300), 'retry': retry_if_operation_in_progress(), @@ -323,9 +324,9 @@ class GoogleBaseHook(BaseHook): 'after': tenacity.after_log(log, logging.DEBUG), } default_kwargs.update(**kwargs) - return tenacity.retry( + return cast(T, tenacity.retry( *args, **default_kwargs - )(fun) + )(fun)) return decorator @staticmethod @@ -357,7 +358,7 @@ class GoogleBaseHook(BaseHook): return inner_wrapper @staticmethod - def provide_gcp_credential_file(func: Callable[..., RT]) -> Callable[..., RT]: + def provide_gcp_credential_file(func: T) -> T: """ Function decorator that provides a GCP credentials for application supporting Application Default Credentials (ADC) strategy. @@ -367,10 +368,10 @@ class GoogleBaseHook(BaseHook): makes it easier to use multiple connection in one function. """ @functools.wraps(func) - def wrapper(self: GoogleBaseHook, *args, **kwargs) -> RT: + def wrapper(self: GoogleBaseHook, *args, **kwargs): with self.provide_gcp_credential_file_as_context(): return func(self, *args, **kwargs) - return wrapper + return cast(T, wrapper) @contextmanager def provide_gcp_credential_file_as_context(self): diff --git a/airflow/stats.py b/airflow/stats.py index 852bafeb8a..7ea9bf35d5 100644 --- a/airflow/stats.py +++ b/airflow/stats.py @@ -16,13 +16,12 @@ # specific language governing permissions and limitations # under the License. - import logging import socket import string import textwrap from functools import wraps -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, TypeVar, cast from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, InvalidStatsNameException @@ -94,7 +93,10 @@ def get_current_handler_stat_name_func() -> Callable[[str], str]: return conf.getimport('scheduler', 'stat_name_handler') or stat_name_default_handler -def validate_stat(fn): +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name + + +def validate_stat(fn: T) -> T: """Check if stat name contains invalid characters. Log and not emit stats if name is invalid """ @@ -108,7 +110,7 @@ def validate_stat(fn): log.error('Invalid stat name: %s.', stat, exc_info=True) return - return wrapper + return cast(T, wrapper) class AllowListValidator: diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index dfcf2355da..76d809c4d0 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -32,7 +32,7 @@ import threading import traceback from argparse import Namespace from datetime import datetime -from typing import Optional +from typing import Callable, Optional, TypeVar, cast from airflow import settings from airflow.exceptions import AirflowException @@ -41,8 +41,10 @@ from airflow.utils import cli_action_loggers from airflow.utils.platform import is_terminal_support_colors from airflow.utils.session import provide_session +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name -def action_logging(f): + +def action_logging(f: T) -> T: """ Decorates function to execute function at the same time submitting action_logging but in CLI context. It will call action logger callbacks twice, @@ -89,7 +91,7 @@ def action_logging(f): metrics['end_datetime'] = datetime.utcnow() cli_action_loggers.on_post_execution(**metrics) - return wrapper + return cast(T, wrapper) def _build_metrics(func_name, namespace): diff --git a/airflow/www/decorators.py b/airflow/www/decorators.py index 68bf466c02..d236df1830 100644 --- a/airflow/www/decorators.py +++ b/airflow/www/decorators.py @@ -19,6 +19,7 @@ import functools import gzip from io import BytesIO as IO +from typing import Callable, TypeVar, cast import pendulum from flask import after_this_request, flash, g, redirect, request, url_for @@ -26,8 +27,10 @@ from flask import after_this_request, flash, g, redirect, request, url_for from airflow.models import Log from airflow.utils.session import create_session +T = TypeVar("T", bound=Callable) # pylint: disable=invalid-name -def action_logging(f): + +def action_logging(f: T) -> T: """ Decorator to log user actions """ @@ -56,10 +59,10 @@ def action_logging(f): return f(*args, **kwargs) - return wrapper + return cast(T, wrapper) -def gzipped(f): +def gzipped(f: T) -> T: """ Decorator to make a view compressed """ @@ -92,14 +95,14 @@ def gzipped(f): return f(*args, **kwargs) - return view_func + return cast(T, view_func) -def has_dag_access(**dag_kwargs): +def has_dag_access(**dag_kwargs) -> Callable[[T], T]: """ Decorator to check whether the user has read / write permission on the dag. """ - def decorator(f): + def decorator(f: T): @functools.wraps(f) def wrapper(self, *args, **kwargs): has_access = self.appbuilder.sm.has_access @@ -124,5 +127,5 @@ def has_dag_access(**dag_kwargs): flash("Access is Denied", "danger") return redirect(url_for(self.appbuilder.sm.auth_view. __class__.__name__ + ".login")) - return wrapper + return cast(T, wrapper) return decorator