diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 68f2f6cabb..174dfa5803 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -370,8 +370,13 @@ class DAG(BaseDag, LoggingMixin): # /Context Manager ---------------------------------------------- - def date_range(self, start_date, num=None, end_date=timezone.utcnow()): - if num: + def date_range( + self, + start_date: datetime, + num: Optional[int] = None, + end_date: Optional[datetime] = timezone.utcnow(), + ) -> List[datetime]: + if num is not None: end_date = None return utils_date_range( start_date=start_date, end_date=end_date, diff --git a/airflow/models/variable.py b/airflow/models/variable.py index 9369f26b5e..e87679033c 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -113,7 +113,7 @@ class Variable(Base, LoggingMixin): key: str, default_var: Any = __NO_DEFAULT_SENTINEL, deserialize_json: bool = False, - ): + ) -> Any: """ Sets a value for an Airflow Key diff --git a/airflow/utils/dates.py b/airflow/utils/dates.py index bd9b73fdb6..cf219677ff 100644 --- a/airflow/utils/dates.py +++ b/airflow/utils/dates.py @@ -17,9 +17,10 @@ # under the License. from datetime import datetime, timedelta +from typing import List, Optional, Union from croniter import croniter -from dateutil.relativedelta import relativedelta # noqa: F401 for doctest # pylint: disable=unused-import +from dateutil.relativedelta import relativedelta # noqa: F401 for doctest from airflow.utils import timezone @@ -33,7 +34,13 @@ cron_presets = { } -def date_range(start_date, end_date=None, num=None, delta=None): # pylint: disable=too-many-branches +# pylint: disable=too-many-branches +def date_range( + start_date: datetime, + end_date: Optional[datetime] = None, + num: Optional[int] = None, + delta: Optional[Union[str, timedelta, relativedelta]] = None, +) -> List[datetime]: """ Get a set of dates as a list based on a start, end and delta, delta can be something that can be added to `datetime.datetime` @@ -60,29 +67,31 @@ def date_range(start_date, end_date=None, num=None, delta=None): # pylint: disa output will always be sorted regardless :type num: int :param delta: step length. It can be datetime.timedelta or cron expression as string - :type delta: datetime.timedelta or str + :type delta: datetime.timedelta or str or dateutil.relativedelta """ if not delta: return [] - if end_date and start_date > end_date: - raise Exception("Wait. start_date needs to be before end_date") - if end_date and num: - raise Exception("Wait. Either specify end_date OR num") + if end_date: + if start_date > end_date: + raise Exception("Wait. start_date needs to be before end_date") + if num: + raise Exception("Wait. Either specify end_date OR num") if not end_date and not num: end_date = timezone.utcnow() - if delta in cron_presets: - delta = cron_presets.get(delta) delta_iscron = False time_zone = start_date.tzinfo + abs_delta: Union[timedelta, relativedelta] if isinstance(delta, str): delta_iscron = True if timezone.is_localized(start_date): start_date = timezone.make_naive(start_date, time_zone) - cron = croniter(delta, start_date) + cron = croniter(cron_presets.get(delta, delta), start_date) elif isinstance(delta, timedelta): - delta = abs(delta) + abs_delta = abs(delta) + elif isinstance(delta, relativedelta): + abs_delta = abs(delta) else: raise Exception("Wait. delta must be either datetime.timedelta or cron expression as str") @@ -90,7 +99,7 @@ def date_range(start_date, end_date=None, num=None, delta=None): # pylint: disa if end_date: if timezone.is_naive(start_date) and not timezone.is_naive(end_date): end_date = timezone.make_naive(end_date, time_zone) - while start_date <= end_date: + while start_date <= end_date: # type: ignore if timezone.is_naive(start_date): dates.append(timezone.make_aware(start_date, time_zone)) else: @@ -99,22 +108,23 @@ def date_range(start_date, end_date=None, num=None, delta=None): # pylint: disa if delta_iscron: start_date = cron.get_next(datetime) else: - start_date += delta + start_date += abs_delta else: - for _ in range(abs(num)): + num_entries: int = num # type: ignore + for _ in range(abs(num_entries)): if timezone.is_naive(start_date): dates.append(timezone.make_aware(start_date, time_zone)) else: dates.append(start_date) - if delta_iscron and num > 0: + if delta_iscron and num_entries > 0: start_date = cron.get_next(datetime) elif delta_iscron: start_date = cron.get_prev(datetime) - elif num > 0: - start_date += delta + elif num_entries > 0: + start_date += abs_delta else: - start_date -= delta + start_date -= abs_delta return sorted(dates) diff --git a/airflow/utils/timezone.py b/airflow/utils/timezone.py index 399a5a2e2d..24fa2d15a9 100644 --- a/airflow/utils/timezone.py +++ b/airflow/utils/timezone.py @@ -19,6 +19,7 @@ import datetime as dt import pendulum +from pendulum.datetime import DateTime from airflow.settings import TIMEZONE @@ -170,10 +171,10 @@ def datetime(*args, **kwargs): return dt.datetime(*args, **kwargs) -def parse(string, timezone=None): +def parse(string: str, timezone=None) -> DateTime: """ Parse a time string and return an aware datetime :param string: time string """ - return pendulum.parse(string, tz=timezone or TIMEZONE, strict=False) + return pendulum.parse(string, tz=timezone or TIMEZONE, strict=False) # type: ignore diff --git a/airflow/www/api/experimental/endpoints.py b/airflow/www/api/experimental/endpoints.py index 1f3ff25d76..6a74f5ce18 100644 --- a/airflow/www/api/experimental/endpoints.py +++ b/airflow/www/api/experimental/endpoints.py @@ -374,7 +374,7 @@ def get_lineage(dag_id: str, execution_date: str): """ Get Lineage details for a DagRun """ # Convert string datetime into actual datetime try: - execution_date = timezone.parse(execution_date) + execution_dt = timezone.parse(execution_date) except ValueError: error_message = ( 'Given execution date, {}, could not be identified ' @@ -387,7 +387,7 @@ def get_lineage(dag_id: str, execution_date: str): return response try: - lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_date) + lineage = get_lineage_api(dag_id=dag_id, execution_date=execution_dt) except AirflowException as err: log.error(err) response = jsonify(error=f"{err}") diff --git a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py index e6795ea8b3..1f06fcc470 100644 --- a/tests/providers/apache/hive/transfers/test_mssql_to_hive.py +++ b/tests/providers/apache/hive/transfers/test_mssql_to_hive.py @@ -24,7 +24,7 @@ from unittest.mock import Mock, PropertyMock, patch from airflow import PY38 if PY38: - MsSqlToHiveTransferOperator = None + MsSqlToHiveTransferOperator: None = None else: from airflow.providers.apache.hive.transfers.mssql_to_hive import MsSqlToHiveOperator diff --git a/tests/test_utils/mock_process.py b/tests/test_utils/mock_process.py index e8ecacdaab..a090291cb5 100644 --- a/tests/test_utils/mock_process.py +++ b/tests/test_utils/mock_process.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +from typing import Optional from unittest import mock @@ -41,7 +42,7 @@ class MockStdOut: class MockSubProcess: PIPE = -1 STDOUT = -2 - returncode = None + returncode: Optional[int] = None def __init__(self, *args, **kwargs): self.stdout = MockStdOut(*args, **kwargs)