Disable row level locking for Mariadb and MySQL <8 (#14031)

closes #11899
closes #13668

This PR disables row-level locking for MySQL variants that do not support skip_locked and nowait -- MySQL < 8 and MariaDB.

(cherry picked from commit 568327f01a)
This commit is contained in:
Kaxil Naik 2021-02-03 02:55:27 +00:00
Родитель c87bf1f204
Коммит 3870392356
7 изменённых файлов: 50 добавлений и 11 удалений

Просмотреть файл

@ -811,7 +811,9 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
# We need to do this for mysql as well because it can cause deadlocks
# as discussed in https://issues.apache.org/jira/browse/AIRFLOW-2516
if self.using_sqlite or self.using_mysql:
tis_to_change: List[TI] = with_row_locks(query, of=TI, **skip_locked(session=session)).all()
tis_to_change: List[TI] = with_row_locks(
query, of=TI, session=session, **skip_locked(session=session)
).all()
for ti in tis_to_change:
ti.set_state(new_state, session=session)
tis_changed += 1
@ -921,6 +923,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
task_instances_to_examine: List[TI] = with_row_locks(
query,
of=TI,
session=session,
**skip_locked(session=session),
).all()
# TODO[HA]: This was wrong before anyway, as it only looked at a sub-set of dags, not everything.
@ -1158,7 +1161,7 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
for dag_id, task_id, execution_date, try_number in self.executor.queued_tasks.keys()
]
ti_query = session.query(TI).filter(or_(*filter_for_ti_state_change))
tis_to_set_to_scheduled: List[TI] = with_row_locks(ti_query).all()
tis_to_set_to_scheduled: List[TI] = with_row_locks(ti_query, session=session).all()
if not tis_to_set_to_scheduled:
return
@ -1827,7 +1830,9 @@ class SchedulerJob(BaseJob): # pylint: disable=too-many-instance-attributes
)
# Lock these rows, so that another scheduler can't try and adopt these too
tis_to_reset_or_adopt = with_row_locks(query, of=TI, **skip_locked(session=session)).all()
tis_to_reset_or_adopt = with_row_locks(
query, of=TI, session=session, **skip_locked(session=session)
).all()
to_reset = self.executor.try_adopt_task_instances(tis_to_reset_or_adopt)
reset_tis_message = []

Просмотреть файл

@ -1823,7 +1823,7 @@ class DAG(LoggingMixin):
.options(joinedload(DagModel.tags, innerjoin=False))
.filter(DagModel.dag_id.in_(dag_ids))
)
orm_dags = with_row_locks(query, of=DagModel).all()
orm_dags = with_row_locks(query, of=DagModel, session=session).all()
existing_dag_ids = {orm_dag.dag_id for orm_dag in orm_dags}
missing_dag_ids = dag_ids.difference(existing_dag_ids)
@ -2246,7 +2246,7 @@ class DagModel(Base):
.limit(cls.NUM_DAGS_PER_DAGRUN_QUERY)
)
return with_row_locks(query, of=cls, **skip_locked(session=session))
return with_row_locks(query, of=cls, session=session, **skip_locked(session=session))
def calculate_dagrun_date_fields(
self, dag: DAG, most_recent_dag_run: Optional[pendulum.DateTime], active_runs_of_dag: int

Просмотреть файл

@ -224,7 +224,9 @@ class DagRun(Base, LoggingMixin):
if not settings.ALLOW_FUTURE_EXEC_DATES:
query = query.filter(DagRun.execution_date <= func.now())
return with_row_locks(query.limit(max_number), of=cls, **skip_locked(session=session))
return with_row_locks(
query.limit(max_number), of=cls, session=session, **skip_locked(session=session)
)
@staticmethod
@provide_session

Просмотреть файл

@ -102,7 +102,7 @@ class Pool(Base):
query = session.query(Pool.pool, Pool.slots)
if lock_rows:
query = with_row_locks(query, **nowait(session))
query = with_row_locks(query, session=session, **nowait(session))
pool_rows: Iterable[Tuple[str, int]] = query.all()
for (pool_name, total_slots) in pool_rows:

Просмотреть файл

@ -1187,7 +1187,8 @@ class TaskInstance(Base, LoggingMixin): # pylint: disable=R0902,R0904
session.query(DagRun).filter_by(
dag_id=self.dag_id,
execution_date=self.execution_date,
)
),
session=session,
).one()
# Get a partial dag with just the specific tasks we want to

Просмотреть файл

@ -188,15 +188,19 @@ def nulls_first(col, session: Session) -> Dict[str, Any]:
USE_ROW_LEVEL_LOCKING: bool = conf.getboolean('scheduler', 'use_row_level_locking', fallback=True)
def with_row_locks(query, **kwargs):
def with_row_locks(query, session: Session, **kwargs):
"""
Apply with_for_update to an SQLAlchemy query, if row level locking is in use.
:param query: An SQLAlchemy Query object
:param session: ORM Session
:param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
:return: updated query
"""
if USE_ROW_LEVEL_LOCKING:
dialect = session.bind.dialect
# Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of):
return query.with_for_update(**kwargs)
else:
return query

Просмотреть файл

@ -27,7 +27,7 @@ from sqlalchemy.exc import StatementError
from airflow import settings
from airflow.models import DAG
from airflow.settings import Session
from airflow.utils.sqlalchemy import nowait, prohibit_commit, skip_locked
from airflow.utils.sqlalchemy import nowait, prohibit_commit, skip_locked, with_row_locks
from airflow.utils.state import State
from airflow.utils.timezone import utcnow
@ -161,6 +161,33 @@ class TestSqlAlchemyUtils(unittest.TestCase):
session.bind.dialect.supports_for_update_of = supports_for_update_of
assert nowait(session=session) == expected_return_value
@parameterized.expand(
[
("postgresql", True, True, True),
("postgresql", True, False, False),
("mysql", False, True, False),
("mysql", False, False, False),
("mysql", True, True, True),
("mysql", True, False, False),
("sqlite", False, True, True),
]
)
def test_with_row_locks(
self, dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock
):
query = mock.Mock()
session = mock.Mock()
session.bind.dialect.name = dialect
session.bind.dialect.supports_for_update_of = supports_for_update_of
with mock.patch("airflow.utils.sqlalchemy.USE_ROW_LEVEL_LOCKING", use_row_level_lock_conf):
returned_value = with_row_locks(query=query, session=session, nowait=True)
if expected_use_row_level_lock:
query.with_for_update.assert_called_once_with(nowait=True)
else:
assert returned_value == query
query.with_for_update.assert_not_called()
def test_prohibit_commit(self):
with prohibit_commit(self.session) as guard:
self.session.execute('SELECT 1')