Add Type Annotations & Docstrings to airflow/models/dagrun.py (#10466)

This commit is contained in:
Kaxil Naik 2020-08-22 09:58:14 +01:00 коммит произвёл GitHub
Родитель 7c206a82a6
Коммит 90b9e7e3c7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 22 добавлений и 10 удалений

Просмотреть файл

@ -111,11 +111,12 @@ class DagRun(Base, LoggingMixin):
return synonym('_state', descriptor=property(self.get_state, self.set_state)) return synonym('_state', descriptor=property(self.get_state, self.set_state))
@provide_session @provide_session
def refresh_from_db(self, session=None): def refresh_from_db(self, session: Session = None):
""" """
Reloads the current dagrun from the database Reloads the current dagrun from the database
:param session: database session :param session: database session
:type session: Session
""" """
DR = DagRun DR = DagRun
@ -203,6 +204,7 @@ class DagRun(Base, LoggingMixin):
@staticmethod @staticmethod
def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str: def generate_run_id(run_type: DagRunType, execution_date: datetime) -> str:
"""Generate Run ID based on Run Type and Execution Date"""
return f"{run_type.value}__{execution_date.isoformat()}" return f"{run_type.value}__{execution_date.isoformat()}"
@provide_session @provide_session
@ -237,11 +239,14 @@ class DagRun(Base, LoggingMixin):
return tis.all() return tis.all()
@provide_session @provide_session
def get_task_instance(self, task_id, session=None): def get_task_instance(self, task_id: str, session: Session = None):
""" """
Returns the task instance specified by task_id for this dag run Returns the task instance specified by task_id for this dag run
:param task_id: the task id :param task_id: the task id
:type task_id: str
:param session: Sqlalchemy ORM Session
:type session: Session
""" """
ti = session.query(TI).filter( ti = session.query(TI).filter(
TI.dag_id == self.dag_id, TI.dag_id == self.dag_id,
@ -258,8 +263,7 @@ class DagRun(Base, LoggingMixin):
:return: DAG :return: DAG
""" """
if not self.dag: if not self.dag:
raise AirflowException("The DAG (.dag) for {} needs to be set" raise AirflowException("The DAG (.dag) for {} needs to be set".format(self))
.format(self))
return self.dag return self.dag
@ -280,7 +284,7 @@ class DagRun(Base, LoggingMixin):
).first() ).first()
@provide_session @provide_session
def get_previous_scheduled_dagrun(self, session=None): def get_previous_scheduled_dagrun(self, session: Session = None):
"""The previous, SCHEDULED DagRun, if there is one""" """The previous, SCHEDULED DagRun, if there is one"""
dag = self.get_dag() dag = self.get_dag()
@ -290,11 +294,13 @@ class DagRun(Base, LoggingMixin):
).first() ).first()
@provide_session @provide_session
def update_state(self, session=None) -> List[TI]: def update_state(self, session: Session = None) -> List[TI]:
""" """
Determines the overall state of the DagRun based on the state Determines the overall state of the DagRun based on the state
of its TaskInstances. of its TaskInstances.
:param session: Sqlalchemy ORM Session
:type session: Session
:return: ready_tis: the tis that can be scheduled in the current loop :return: ready_tis: the tis that can be scheduled in the current loop
:rtype ready_tis: list[airflow.models.TaskInstance] :rtype ready_tis: list[airflow.models.TaskInstance]
""" """
@ -336,8 +342,7 @@ class DagRun(Base, LoggingMixin):
): ):
self.log.error('Marking run %s failed', self) self.log.error('Marking run %s failed', self)
self.set_state(State.FAILED) self.set_state(State.FAILED)
dag.handle_callback(self, success=False, reason='task_failure', dag.handle_callback(self, success=False, reason='task_failure', session=session)
session=session)
# if all leafs succeeded and no unfinished tasks, the run succeeded # if all leafs succeeded and no unfinished tasks, the run succeeded
elif not unfinished_tasks and all( elif not unfinished_tasks and all(
@ -430,10 +435,13 @@ class DagRun(Base, LoggingMixin):
Stats.timing('dagrun.duration.failed.{}'.format(self.dag_id), duration) Stats.timing('dagrun.duration.failed.{}'.format(self.dag_id), duration)
@provide_session @provide_session
def verify_integrity(self, session=None): def verify_integrity(self, session: Session = None):
""" """
Verifies the DagRun by checking for removed tasks or tasks that are not in the Verifies the DagRun by checking for removed tasks or tasks that are not in the
database yet. It will set state to removed or add the task if required. database yet. It will set state to removed or add the task if required.
:param session: Sqlalchemy ORM Session
:type session: Session
""" """
dag = self.get_dag() dag = self.get_dag()
tis = self.get_task_instances(session=session) tis = self.get_task_instances(session=session)
@ -487,8 +495,12 @@ class DagRun(Base, LoggingMixin):
session.rollback() session.rollback()
@staticmethod @staticmethod
def get_run(session, dag_id, execution_date): def get_run(session: Session, dag_id: str, execution_date: datetime):
""" """
Get a single DAG Run
:param session: Sqlalchemy ORM Session
:type session: Session
:param dag_id: DAG ID :param dag_id: DAG ID
:type dag_id: unicode :type dag_id: unicode
:param execution_date: execution date :param execution_date: execution date