Show the location of the queries when assert_queries_count fails. (#11186)

Example output (I forced one of the existing tests to fail):

```
E   AssertionError: The expected number of db queries is 3. The current number is 2.
E
E   Recorded query locations:
E   	scheduler_job.py:_run_scheduler_loop>scheduler_job.py:_emit_pool_metrics>pool.py:slots_stats:94:	1
E   	scheduler_job.py:_run_scheduler_loop>scheduler_job.py:_emit_pool_metrics>pool.py:slots_stats:101:	1
```

This makes it a bit easier to see what the queries are, without having
to re-run with full query tracing and then analyze the logs.
This commit is contained in:
Ash Berlin-Taylor 2020-09-28 19:39:21 +01:00 коммит произвёл GitHub
Родитель e2dc706b08
Коммит 6694eaa831
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
1 изменённых файлов: 26 добавлений и 13 удалений

Просмотреть файл

@ -17,6 +17,8 @@
import logging
import re
import traceback
from collections import Counter
from contextlib import contextmanager
from sqlalchemy import event
@ -33,11 +35,6 @@ def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
return case.assertEqual(_trim(first), _trim(second), msg)
class CountQueriesResult:
def __init__(self):
self.count = 0
class CountQueries:
"""
Counts the number of queries sent to Airflow Database in a given context.
@ -46,18 +43,26 @@ class CountQueries:
not be included.
"""
def __init__(self):
self.result = CountQueriesResult()
self.result = Counter()
def __enter__(self):
event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
return self.result
def __exit__(self, type_, value, traceback):
def __exit__(self, type_, value, tb):
event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
log.debug("Queries count: %d", self.result.count)
log.debug("Queries count: %d", sum(self.result.values()))
def after_cursor_execute(self, *args, **kwargs):
self.result.count += 1
stack = [
f for f in traceback.extract_stack()
if 'sqlalchemy' not in f.filename and
__file__ != f.filename and
('session.py' not in f.filename and f.name != 'wrapper')
]
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:])
lineno = stack[-1].lineno
self.result[f"{stack_info}:{lineno}"] += 1
count_queries = CountQueries # pylint: disable=invalid-name
@ -67,7 +72,15 @@ count_queries = CountQueries # pylint: disable=invalid-name
def assert_queries_count(expected_count, message_fmt=None):
with count_queries() as result:
yield None
count = sum(result.values())
if expected_count != count:
message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \
"The current number is {current_count}."
message = message_fmt.format(current_count=result.count, expected_count=expected_count)
assert expected_count == result.count, message
"The current number is {current_count}.\n\n" \
"Recorded query locations:"
message = message_fmt.format(current_count=count, expected_count=expected_count)
for location, count in result.items():
message += f'\n\t{location}:\t{count}'
raise AssertionError(message)