diff --git a/tests/test_utils/asserts.py b/tests/test_utils/asserts.py index ca3cf2fd62..220331de15 100644 --- a/tests/test_utils/asserts.py +++ b/tests/test_utils/asserts.py @@ -17,6 +17,8 @@ import logging import re +import traceback +from collections import Counter from contextlib import contextmanager from sqlalchemy import event @@ -33,11 +35,6 @@ def assert_equal_ignore_multiple_spaces(case, first, second, msg=None): return case.assertEqual(_trim(first), _trim(second), msg) -class CountQueriesResult: - def __init__(self): - self.count = 0 - - class CountQueries: """ Counts the number of queries sent to Airflow Database in a given context. @@ -46,18 +43,26 @@ class CountQueries: not be included. """ def __init__(self): - self.result = CountQueriesResult() + self.result = Counter() def __enter__(self): event.listen(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) return self.result - def __exit__(self, type_, value, traceback): + def __exit__(self, type_, value, tb): event.remove(airflow.settings.engine, "after_cursor_execute", self.after_cursor_execute) - log.debug("Queries count: %d", self.result.count) + log.debug("Queries count: %d", sum(self.result.values())) def after_cursor_execute(self, *args, **kwargs): - self.result.count += 1 + stack = [ + f for f in traceback.extract_stack() + if 'sqlalchemy' not in f.filename and + __file__ != f.filename and + ('session.py' not in f.filename and f.name != 'wrapper') + ] + stack_info = ">".join([f"{f.filename.rpartition('/')[-1]}:{f.name}" for f in stack][-3:]) + lineno = stack[-1].lineno + self.result[f"{stack_info}:{lineno}"] += 1 count_queries = CountQueries # pylint: disable=invalid-name @@ -67,7 +72,15 @@ count_queries = CountQueries # pylint: disable=invalid-name def assert_queries_count(expected_count, message_fmt=None): with count_queries() as result: yield None - message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \ - "The current number is {current_count}." - message = message_fmt.format(current_count=result.count, expected_count=expected_count) - assert expected_count == result.count, message + + count = sum(result.values()) + if expected_count != count: + message_fmt = message_fmt or "The expected number of db queries is {expected_count}. " \ + "The current number is {current_count}.\n\n" \ + "Recorded query locations:" + message = message_fmt.format(current_count=count, expected_count=expected_count) + + for location, count in result.items(): + message += f'\n\t{location}:\t{count}' + + raise AssertionError(message)