[AIRFLOW-6026] Use contextlib to redirect stderr and stdout (#6624)

This commit is contained in:
Kamil Breguła 2019-11-21 19:57:00 +01:00 коммит произвёл GitHub
Родитель 1d8b8cfcbc
Коммит da086661f7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
11 изменённых файлов: 87 добавлений и 92 удалений

Просмотреть файл

@ -41,6 +41,36 @@ assists users migrating to a new version.
## Airflow Master ## Airflow Master
### Removal of redirect_stdout, redirect_stderr
The functions `redirect_stderr` and `redirect_stdout` from the `airflow.utils.log.logging_mixin` module have
been deleted because they can be easily replaced by the standard library.
The functions of the standard library are more flexible and can be used in more cases.
The code below
```python
import logging
from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
logger = logging.getLogger("custom-logger")
with redirect_stdout(logger, logging.INFO), redirect_stderr(logger, logging.WARN):
print("I love Airflow")
```
can be replaced by the following code:
```python
from contextlib import redirect_stdout, redirect_stderr
import logging
from airflow.utils.log.logging_mixin import StreamLogWriter
logger = logging.getLogger("custom-logger")
with redirect_stdout(StreamLogWriter(logger, logging.INFO)), \
redirect_stderr(StreamLogWriter(logger, logging.WARN)):
print("I love Airflow")
```
### Removal of XCom.get_one() ### Removal of XCom.get_one()
This one is superseded by `XCom.get_many().first()` which will return the same result. This one is superseded by `XCom.get_many().first()` which will return the same result.

Просмотреть файл

@ -22,6 +22,7 @@ import json
import logging import logging
import os import os
import textwrap import textwrap
from contextlib import redirect_stderr, redirect_stdout
from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings
from airflow.executors import get_default_executor from airflow.executors import get_default_executor
@ -29,7 +30,7 @@ from airflow.models import DagPickle, TaskInstance
from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext
from airflow.utils import cli as cli_utils, db from airflow.utils import cli as cli_utils, db
from airflow.utils.cli import get_dag, get_dags from airflow.utils.cli import get_dag, get_dags
from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout from airflow.utils.log.logging_mixin import StreamLogWriter
from airflow.utils.net import get_hostname from airflow.utils.net import get_hostname
@ -131,7 +132,8 @@ def task_run(args, dag=None):
if args.interactive: if args.interactive:
_run(args, dag, ti) _run(args, dag, ti)
else: else:
with redirect_stdout(ti.log, logging.INFO), redirect_stderr(ti.log, logging.WARN): with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \
redirect_stderr(StreamLogWriter(ti.log, logging.WARN)):
_run(args, dag, ti) _run(args, dag, ti)
logging.shutdown() logging.shutdown()

Просмотреть файл

@ -26,6 +26,7 @@ import sys
import threading import threading
import time import time
from collections import defaultdict from collections import defaultdict
from contextlib import redirect_stderr, redirect_stdout
from datetime import timedelta from datetime import timedelta
from time import sleep from time import sleep
from typing import List, Set from typing import List, Set
@ -122,45 +123,40 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin):
# This helper runs in the newly created process # This helper runs in the newly created process
log = logging.getLogger("airflow.processor") log = logging.getLogger("airflow.processor")
stdout = StreamLogWriter(log, logging.INFO)
stderr = StreamLogWriter(log, logging.WARN)
set_context(log, file_path) set_context(log, file_path)
setproctitle("airflow scheduler - DagFileProcessor {}".format(file_path)) setproctitle("airflow scheduler - DagFileProcessor {}".format(file_path))
try: try:
# redirect stdout/stderr to log # redirect stdout/stderr to log
sys.stdout = stdout with redirect_stdout(StreamLogWriter(log, logging.INFO)),\
sys.stderr = stderr redirect_stderr(StreamLogWriter(log, logging.WARN)):
# Re-configure the ORM engine as there are issues with multiple processes # Re-configure the ORM engine as there are issues with multiple processes
settings.configure_orm() settings.configure_orm()
# Change the thread name to differentiate log lines. This is # Change the thread name to differentiate log lines. This is
# really a separate process, but changing the name of the # really a separate process, but changing the name of the
# process doesn't work, so changing the thread name instead. # process doesn't work, so changing the thread name instead.
threading.current_thread().name = thread_name threading.current_thread().name = thread_name
start_time = time.time() start_time = time.time()
log.info("Started process (PID=%s) to work on %s", log.info("Started process (PID=%s) to work on %s",
os.getpid(), file_path) os.getpid(), file_path)
scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log) scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log)
result = scheduler_job.process_file(file_path, result = scheduler_job.process_file(file_path,
zombies, zombies,
pickle_dags) pickle_dags)
result_channel.send(result) result_channel.send(result)
end_time = time.time() end_time = time.time()
log.info( log.info(
"Processing %s took %.3f seconds", file_path, end_time - start_time "Processing %s took %.3f seconds", file_path, end_time - start_time
) )
except Exception: except Exception:
# Log exceptions through the logging framework. # Log exceptions through the logging framework.
log.exception("Got an exception! Propagating...") log.exception("Got an exception! Propagating...")
raise raise
finally: finally:
result_channel.close() result_channel.close()
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# We re-initialized the ORM within this Process above so we need to # We re-initialized the ORM within this Process above so we need to
# tear it down manually here # tear it down manually here
settings.dispose_orm() settings.dispose_orm()

Просмотреть файл

@ -19,7 +19,6 @@
import logging import logging
import re import re
import sys import sys
from contextlib import contextmanager
from logging import Handler, Logger, StreamHandler from logging import Handler, Logger, StreamHandler
# 7-bit C1 ANSI escape sequences # 7-bit C1 ANSI escape sequences
@ -144,26 +143,6 @@ class RedirectStdHandler(StreamHandler):
return sys.stdout return sys.stdout
@contextmanager
def redirect_stdout(logger, level):
writer = StreamLogWriter(logger, level)
try:
sys.stdout = writer
yield
finally:
sys.stdout = sys.__stdout__
@contextmanager
def redirect_stderr(logger, level):
writer = StreamLogWriter(logger, level)
try:
sys.stderr = writer
yield
finally:
sys.stderr = sys.__stderr__
def set_context(logger, value): def set_context(logger, value):
""" """
Walks the tree of loggers and tries to set the context for each handler Walks the tree of loggers and tries to set the context for each handler

Просмотреть файл

@ -20,7 +20,7 @@ import re
import subprocess import subprocess
import tempfile import tempfile
import unittest import unittest
from unittest import mock from contextlib import redirect_stdout
from airflow import settings from airflow import settings
from airflow.bin import cli from airflow.bin import cli
@ -35,9 +35,9 @@ class TestCliConnections(unittest.TestCase):
cls.parser = cli.CLIFactory.get_parser() cls.parser = cli.CLIFactory.get_parser()
def test_cli_connections_list(self): def test_cli_connections_list(self):
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_list(self.parser.parse_args(['connections', 'list'])) connection_command.connections_list(self.parser.parse_args(['connections', 'list']))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]] conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]]
for ii, line in enumerate(stdout.split('\n')) for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1] if ii % 2 == 1]
@ -71,7 +71,7 @@ class TestCliConnections(unittest.TestCase):
db.resetdb() db.resetdb()
# Add connections: # Add connections:
uri = 'postgresql://airflow:airflow@host:5432/airflow' uri = 'postgresql://airflow:airflow@host:5432/airflow'
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_add(self.parser.parse_args( connection_command.connections_add(self.parser.parse_args(
['connections', 'add', 'new1', ['connections', 'add', 'new1',
'--conn_uri=%s' % uri])) '--conn_uri=%s' % uri]))
@ -92,7 +92,7 @@ class TestCliConnections(unittest.TestCase):
connection_command.connections_add(self.parser.parse_args( connection_command.connections_add(self.parser.parse_args(
['connections', 'add', 'new6', ['connections', 'add', 'new6',
'--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"])) '--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"]))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
# Check addition stdout # Check addition stdout
lines = [l for l in stdout.split('\n') if len(l) > 0] lines = [l for l in stdout.split('\n') if len(l) > 0]
@ -112,11 +112,11 @@ class TestCliConnections(unittest.TestCase):
]) ])
# Attempt to add duplicate # Attempt to add duplicate
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_add(self.parser.parse_args( connection_command.connections_add(self.parser.parse_args(
['connections', 'add', 'new1', ['connections', 'add', 'new1',
'--conn_uri=%s' % uri])) '--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
# Check stdout for addition attempt # Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0] lines = [l for l in stdout.split('\n') if len(l) > 0]
@ -161,7 +161,7 @@ class TestCliConnections(unittest.TestCase):
None, None, "{'extra': 'yes'}")) None, None, "{'extra': 'yes'}"))
# Delete connections # Delete connections
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_delete(self.parser.parse_args( connection_command.connections_delete(self.parser.parse_args(
['connections', 'delete', 'new1'])) ['connections', 'delete', 'new1']))
connection_command.connections_delete(self.parser.parse_args( connection_command.connections_delete(self.parser.parse_args(
@ -174,7 +174,7 @@ class TestCliConnections(unittest.TestCase):
['connections', 'delete', 'new5'])) ['connections', 'delete', 'new5']))
connection_command.connections_delete(self.parser.parse_args( connection_command.connections_delete(self.parser.parse_args(
['connections', 'delete', 'new6'])) ['connections', 'delete', 'new6']))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
# Check deletion stdout # Check deletion stdout
lines = [l for l in stdout.split('\n') if len(l) > 0] lines = [l for l in stdout.split('\n') if len(l) > 0]
@ -197,10 +197,10 @@ class TestCliConnections(unittest.TestCase):
self.assertTrue(result is None) self.assertTrue(result is None)
# Attempt to delete a non-existing connection # Attempt to delete a non-existing connection
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_delete(self.parser.parse_args( connection_command.connections_delete(self.parser.parse_args(
['connections', 'delete', 'fake'])) ['connections', 'delete', 'fake']))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
# Check deletion attempt stdout # Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0] lines = [l for l in stdout.split('\n') if len(l) > 0]

Просмотреть файл

@ -138,14 +138,12 @@ class TestCliDags(unittest.TestCase):
mock_run.reset_mock() mock_run.reset_mock()
dag = self.dagbag.get_dag('example_bash_operator') dag = self.dagbag.get_dag('example_bash_operator')
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with contextlib.redirect_stdout(io.StringIO()) as stdout:
dag_command.dag_backfill(self.parser.parse_args([ dag_command.dag_backfill(self.parser.parse_args([
'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run', 'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run',
'-s', DEFAULT_DATE.isoformat()]), dag=dag) '-s', DEFAULT_DATE.isoformat()]), dag=dag)
mock_stdout.seek(0, 0) output = stdout.getvalue()
output = mock_stdout.read()
self.assertIn("Dry run of DAG example_bash_operator on {}\n".format(DEFAULT_DATE.isoformat()), output) self.assertIn("Dry run of DAG example_bash_operator on {}\n".format(DEFAULT_DATE.isoformat()), output)
self.assertIn("Task runme_0\n".format(DEFAULT_DATE.isoformat()), output) self.assertIn("Task runme_0\n".format(DEFAULT_DATE.isoformat()), output)
@ -179,8 +177,7 @@ class TestCliDags(unittest.TestCase):
mock_run.reset_mock() mock_run.reset_mock()
def test_show_dag_print(self): def test_show_dag_print(self):
temp_stdout = io.StringIO() with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
with contextlib.redirect_stdout(temp_stdout):
dag_command.dag_show(self.parser.parse_args([ dag_command.dag_show(self.parser.parse_args([
'dags', 'show', 'example_bash_operator'])) 'dags', 'show', 'example_bash_operator']))
out = temp_stdout.getvalue() out = temp_stdout.getvalue()
@ -190,8 +187,7 @@ class TestCliDags(unittest.TestCase):
@mock.patch("airflow.cli.commands.dag_command.render_dag") @mock.patch("airflow.cli.commands.dag_command.render_dag")
def test_show_dag_dave(self, mock_render_dag): def test_show_dag_dave(self, mock_render_dag):
temp_stdout = io.StringIO() with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
with contextlib.redirect_stdout(temp_stdout):
dag_command.dag_show(self.parser.parse_args([ dag_command.dag_show(self.parser.parse_args([
'dags', 'show', 'example_bash_operator', '--save', 'awesome.png'] 'dags', 'show', 'example_bash_operator', '--save', 'awesome.png']
)) ))
@ -206,8 +202,7 @@ class TestCliDags(unittest.TestCase):
def test_show_dag_imgcat(self, mock_render_dag, mock_popen): def test_show_dag_imgcat(self, mock_render_dag, mock_popen):
mock_render_dag.return_value.pipe.return_value = b"DOT_DATA" mock_render_dag.return_value.pipe.return_value = b"DOT_DATA"
mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR") mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR")
temp_stdout = io.StringIO() with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
with contextlib.redirect_stdout(temp_stdout):
dag_command.dag_show(self.parser.parse_args([ dag_command.dag_show(self.parser.parse_args([
'dags', 'show', 'example_bash_operator', '--imgcat'] 'dags', 'show', 'example_bash_operator', '--imgcat']
)) ))

Просмотреть файл

@ -19,7 +19,7 @@
# #
import io import io
import unittest import unittest
from unittest import mock from contextlib import redirect_stdout
from airflow import models from airflow import models
from airflow.bin import cli from airflow.bin import cli
@ -82,9 +82,9 @@ class TestCliRoles(unittest.TestCase):
self.appbuilder.sm.add_role('FakeTeamA') self.appbuilder.sm.add_role('FakeTeamA')
self.appbuilder.sm.add_role('FakeTeamB') self.appbuilder.sm.add_role('FakeTeamB')
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
role_command.roles_list(self.parser.parse_args(['roles', 'list'])) role_command.roles_list(self.parser.parse_args(['roles', 'list']))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
self.assertIn('FakeTeamA', stdout) self.assertIn('FakeTeamA', stdout)
self.assertIn('FakeTeamB', stdout) self.assertIn('FakeTeamB', stdout)

Просмотреть файл

@ -18,9 +18,9 @@
# under the License. # under the License.
# #
import io import io
import sys
import unittest import unittest
from argparse import Namespace from argparse import Namespace
from contextlib import redirect_stdout
from datetime import datetime, timedelta from datetime import datetime, timedelta
from unittest import mock from unittest import mock
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -122,16 +122,10 @@ class TestCliTasks(unittest.TestCase):
execution_date=timezone.parse('2018-01-01') execution_date=timezone.parse('2018-01-01')
) )
saved_stdout = sys.stdout with redirect_stdout(io.StringIO()) as stdout:
try:
sys.stdout = out = io.StringIO()
task_command.task_test(args) task_command.task_test(args)
# Check that prints, and log messages, are shown
output = out.getvalue() self.assertIn("'example_python_operator__print_the_context__20180101'", stdout.getvalue())
# Check that prints, and log messages, are shown
self.assertIn("'example_python_operator__print_the_context__20180101'", output)
finally:
sys.stdout = saved_stdout
@mock.patch("airflow.cli.commands.task_command.jobs.LocalTaskJob") @mock.patch("airflow.cli.commands.task_command.jobs.LocalTaskJob")
def test_run_naive_taskinstance(self, mock_local_job): def test_run_naive_taskinstance(self, mock_local_job):

Просмотреть файл

@ -20,7 +20,7 @@ import json
import os import os
import tempfile import tempfile
import unittest import unittest
from unittest import mock from contextlib import redirect_stdout
from airflow import models from airflow import models
from airflow.bin import cli from airflow.bin import cli
@ -100,9 +100,9 @@ class TestCliUsers(unittest.TestCase):
'--use_random_password' '--use_random_password'
]) ])
user_command.users_create(args) user_command.users_create(args)
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
user_command.users_list(self.parser.parse_args(['users', 'list'])) user_command.users_list(self.parser.parse_args(['users', 'list']))
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
for i in range(0, 3): for i in range(0, 3):
self.assertIn('user{}'.format(i), stdout) self.assertIn('user{}'.format(i), stdout)

Просмотреть файл

@ -17,7 +17,7 @@
import io import io
import unittest import unittest
from unittest import mock from contextlib import redirect_stdout
import airflow.cli.commands.version_command import airflow.cli.commands.version_command
from airflow.bin import cli from airflow.bin import cli
@ -30,7 +30,6 @@ class TestCliVersion(unittest.TestCase):
cls.parser = cli.CLIFactory.get_parser() cls.parser = cli.CLIFactory.get_parser()
def test_cli_version(self): def test_cli_version(self):
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
airflow.cli.commands.version_command.version(self.parser.parse_args(['version'])) airflow.cli.commands.version_command.version(self.parser.parse_args(['version']))
stdout = mock_stdout.getvalue() self.assertIn(version, stdout.getvalue())
self.assertIn(version, stdout)

Просмотреть файл

@ -23,8 +23,8 @@ import logging
import os import os
import re import re
import unittest import unittest
from contextlib import redirect_stdout
from tempfile import NamedTemporaryFile from tempfile import NamedTemporaryFile
from unittest import mock
from unittest.mock import patch from unittest.mock import patch
import pendulum import pendulum
@ -908,9 +908,9 @@ class TestDag(unittest.TestCase):
t3 = DummyOperator(task_id="t3") t3 = DummyOperator(task_id="t3")
t1 >> t2 >> t3 t1 >> t2 >> t3
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout: with redirect_stdout(io.StringIO()) as stdout:
dag.tree_view() dag.tree_view()
stdout = mock_stdout.getvalue() stdout = stdout.getvalue()
stdout_lines = stdout.split("\n") stdout_lines = stdout.split("\n")
self.assertIn('t1', stdout_lines[0]) self.assertIn('t1', stdout_lines[0])