Show the location of the queries when the assert_queries_count fails. (#11186)
Example output (I forced one of the existing tests to fail) ``` E AssertionError: The expected number of db queries is 3. The current number is 2. E E Recorded query locations: E scheduler_job.py:_run_scheduler_loop>scheduler_job.py:_emit_pool_metrics>pool.py:slots_stats:94: 1 E scheduler_job.py:_run_scheduler_loop>scheduler_job.py:_emit_pool_metrics>pool.py:slots_stats:101: 1 ``` This makes it a bit easier to see what the queries are, without having to re-run with full query tracing and then analyze the logs.
This commit is contained in:
Родитель
e2dc706b08
Коммит
6694eaa831
|
@ -17,6 +17,8 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy import event
|
||||
|
@ -33,11 +35,6 @@ def assert_equal_ignore_multiple_spaces(case, first, second, msg=None):
|
|||
return case.assertEqual(_trim(first), _trim(second), msg)
|
||||
|
||||
|
||||
class CountQueriesResult:
|
||||
def __init__(self):
|
||||
self.count = 0
|
||||
|
||||
|
||||
class CountQueries:
|
||||
"""
|
||||
Counts the number of queries sent to Airflow Database in a given context.
|
||||
|
@ -46,18 +43,26 @@ class CountQueries:
|
|||
not be included.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.result = CountQueriesResult()
|
||||
self.result = Counter()
|
||||
|
||||
def __enter__(self):
|
||||
event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
|
||||
return self.result
|
||||
|
||||
def __exit__(self, type_, value, traceback):
|
||||
def __exit__(self, type_, value, tb):
|
||||
event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute)
|
||||
log.debug("Queries count: %d", self.result.count)
|
||||
log.debug("Queries count: %d", sum(self.result.values()))
|
||||
|
||||
def after_cursor_execute(self, *args, **kwargs):
|
||||
self.result.count += 1
|
||||
stack = [
|
||||
f for f in traceback.extract_stack()
|
||||
if 'sqlalchemy' not in f.filename and
|
||||
__file__ != f.filename and
|
||||
('session.py' not in f.filename and f.name != 'wrapper')
|
||||
]
|
||||
stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:])
|
||||
lineno = stack[-1].lineno
|
||||
self.result[f"{stack_info}:{lineno}"] += 1
|
||||
|
||||
|
||||
count_queries = CountQueries # pylint: disable=invalid-name
|
||||
|
@ -67,7 +72,15 @@ count_queries = CountQueries # pylint: disable=invalid-name
|
|||
def assert_queries_count(expected_count, message_fmt=None):
|
||||
with count_queries() as result:
|
||||
yield None
|
||||
|
||||
count = sum(result.values())
|
||||
if expected_count != count:
|
||||
message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \
|
||||
"The current number is {current_count}."
|
||||
message = message_fmt.format(current_count=result.count, expected_count=expected_count)
|
||||
assert expected_count == result.count, message
|
||||
"The current number is {current_count}.\n\n" \
|
||||
"Recorded query locations:"
|
||||
message = message_fmt.format(current_count=count, expected_count=expected_count)
|
||||
|
||||
for location, count in result.items():
|
||||
message += f'\n\t{location}:\t{count}'
|
||||
|
||||
raise AssertionError(message)
|
||||
|
|
Загрузка…
Ссылка в новой задаче