[AIRFLOW-6026] Use contextlib to redirect stderr and stdout (#6624)

Kamil Breguła 2019-11-21 19:57:00 +01:00 committed by GitHub
Parent 1d8b8cfcbc
Commit da086661f7
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 87 additions and 92 deletions

View file

@@ -41,6 +41,36 @@ assists users migrating to a new version.
## Airflow Master
### Removal of redirect_stdout, redirect_stderr
The `redirect_stderr` and `redirect_stdout` functions from the `airflow.utils.log.logging_mixin` module have
been deleted because they can easily be replaced by the standard library.
The standard library functions are more flexible and can be used in more cases.
The code below
```python
import logging
from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
logger = logging.getLogger("custom-logger")
with redirect_stdout(logger, logging.INFO), redirect_stderr(logger, logging.WARN):
    print("I love Airflow")
```
can be replaced by the following code:
```python
from contextlib import redirect_stdout, redirect_stderr
import logging
from airflow.utils.log.logging_mixin import StreamLogWriter
logger = logging.getLogger("custom-logger")
with redirect_stdout(StreamLogWriter(logger, logging.INFO)), \
        redirect_stderr(StreamLogWriter(logger, logging.WARN)):
    print("I love Airflow")
```
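The standard library context managers accept any file-like target, not just `StreamLogWriter`. The same pattern therefore also captures output for assertions, which is how the test suites touched by this commit use it, e.g.:
```python
import io
from contextlib import redirect_stdout

with redirect_stdout(io.StringIO()) as stdout:
    print("I love Airflow")
assert stdout.getvalue() == "I love Airflow\n"
```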
### Removal of XCom.get_one()
This method is superseded by `XCom.get_many().first()`, which returns the same result.
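For illustration, the migration might look like the sketch below; the keyword arguments shown are assumptions for the example, not taken from this commit:
```python
from airflow.models import XCom
from airflow.utils import timezone

# Before (removed): XCom.get_one(...) returned a single matching value.
# After: get_many() returns the matching XComs; first() yields the same single result.
# The filter arguments below (execution_date, dag_ids, task_ids, key) are illustrative.
execution_date = timezone.utcnow()
value = XCom.get_many(
    execution_date=execution_date,
    dag_ids="example_dag",
    task_ids="example_task",
    key="example_key",
).first()
```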

View file

@@ -22,6 +22,7 @@ import json
import logging
import os
import textwrap
from contextlib import redirect_stderr, redirect_stdout
from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings
from airflow.executors import get_default_executor
@@ -29,7 +30,7 @@ from airflow.models import DagPickle, TaskInstance
from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext
from airflow.utils import cli as cli_utils, db
from airflow.utils.cli import get_dag, get_dags
from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
from airflow.utils.log.logging_mixin import StreamLogWriter
from airflow.utils.net import get_hostname
@@ -131,7 +132,8 @@ def task_run(args, dag=None):
if args.interactive:
_run(args, dag, ti)
else:
with redirect_stdout(ti.log, logging.INFO), redirect_stderr(ti.log, logging.WARN):
with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \
redirect_stderr(StreamLogWriter(ti.log, logging.WARN)):
_run(args, dag, ti)
logging.shutdown()
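For context, `StreamLogWriter` is what makes a logger usable as a redirection target: it is a file-like object whose writes become log records. A simplified stand-in (a sketch, not the actual Airflow implementation) might look like this:
```python
import logging
from contextlib import redirect_stderr, redirect_stdout


class LogWriter:
    """File-like wrapper forwarding writes to a logger (a simplified
    stand-in for airflow.utils.log.logging_mixin.StreamLogWriter)."""

    def __init__(self, logger, level):
        self.logger = logger
        self.level = level

    def write(self, message):
        # print() emits a separate bare "\n"; skip it to avoid empty records.
        if message.strip():
            self.logger.log(self.level, message.rstrip())

    def flush(self):
        # No buffering here; present only to satisfy the stream interface.
        pass


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("task-logger")
with redirect_stdout(LogWriter(logger, logging.INFO)), \
        redirect_stderr(LogWriter(logger, logging.WARN)):
    print("routed to the logger")  # shows up as a log record, not on stdout
```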

View file

@@ -26,6 +26,7 @@ import sys
import threading
import time
from collections import defaultdict
from contextlib import redirect_stderr, redirect_stdout
from datetime import timedelta
from time import sleep
from typing import List, Set
@@ -122,45 +123,40 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin):
# This helper runs in the newly created process
log = logging.getLogger("airflow.processor")
stdout = StreamLogWriter(log, logging.INFO)
stderr = StreamLogWriter(log, logging.WARN)
set_context(log, file_path)
setproctitle("airflow scheduler - DagFileProcessor {}".format(file_path))
try:
# redirect stdout/stderr to log
sys.stdout = stdout
sys.stderr = stderr
with redirect_stdout(StreamLogWriter(log, logging.INFO)),\
redirect_stderr(StreamLogWriter(log, logging.WARN)):
# Re-configure the ORM engine as there are issues with multiple processes
settings.configure_orm()
# Re-configure the ORM engine as there are issues with multiple processes
settings.configure_orm()
# Change the thread name to differentiate log lines. This is
# really a separate process, but changing the name of the
# process doesn't work, so changing the thread name instead.
threading.current_thread().name = thread_name
start_time = time.time()
# Change the thread name to differentiate log lines. This is
# really a separate process, but changing the name of the
# process doesn't work, so changing the thread name instead.
threading.current_thread().name = thread_name
start_time = time.time()
log.info("Started process (PID=%s) to work on %s",
os.getpid(), file_path)
scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log)
result = scheduler_job.process_file(file_path,
zombies,
pickle_dags)
result_channel.send(result)
end_time = time.time()
log.info(
"Processing %s took %.3f seconds", file_path, end_time - start_time
)
log.info("Started process (PID=%s) to work on %s",
os.getpid(), file_path)
scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log)
result = scheduler_job.process_file(file_path,
zombies,
pickle_dags)
result_channel.send(result)
end_time = time.time()
log.info(
"Processing %s took %.3f seconds", file_path, end_time - start_time
)
except Exception:
# Log exceptions through the logging framework.
log.exception("Got an exception! Propagating...")
raise
finally:
result_channel.close()
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__
# We re-initialized the ORM within this Process above so we need to
# tear it down manually here
settings.dispose_orm()
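A side benefit of the rewrite above: `contextlib`'s managers restore the original streams even if the wrapped block raises, which is why the manual resets of `sys.stdout`/`sys.stderr` could be dropped from the `finally` clause. A minimal sketch of that guarantee (assuming nothing else has replaced `sys.stdout` beforehand):
```python
import io
import sys
from contextlib import redirect_stdout

buffer = io.StringIO()
try:
    with redirect_stdout(buffer):
        print("captured")
        raise RuntimeError("boom")  # simulate a failure mid-redirect
except RuntimeError:
    pass

assert sys.stdout is sys.__stdout__  # stream restored despite the exception
assert buffer.getvalue() == "captured\n"
```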

View file

@@ -19,7 +19,6 @@
import logging
import re
import sys
from contextlib import contextmanager
from logging import Handler, Logger, StreamHandler
# 7-bit C1 ANSI escape sequences
@@ -144,26 +143,6 @@ class RedirectStdHandler(StreamHandler):
return sys.stdout
@contextmanager
def redirect_stdout(logger, level):
writer = StreamLogWriter(logger, level)
try:
sys.stdout = writer
yield
finally:
sys.stdout = sys.__stdout__
@contextmanager
def redirect_stderr(logger, level):
writer = StreamLogWriter(logger, level)
try:
sys.stderr = writer
yield
finally:
sys.stderr = sys.__stderr__
def set_context(logger, value):
"""
Walks the tree of loggers and tries to set the context for each handler
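Code that still wants helpers with the removed signature can recreate them as thin wrappers over the standard library; a sketch (these wrappers are not part of Airflow's API after this change):
```python
import contextlib

from airflow.utils.log.logging_mixin import StreamLogWriter


def redirect_stdout(logger, level):
    """Context manager routing sys.stdout writes to `logger` at `level`."""
    return contextlib.redirect_stdout(StreamLogWriter(logger, level))


def redirect_stderr(logger, level):
    """Context manager routing sys.stderr writes to `logger` at `level`."""
    return contextlib.redirect_stderr(StreamLogWriter(logger, level))
```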

View file

@@ -20,7 +20,7 @@ import re
import subprocess
import tempfile
import unittest
from unittest import mock
from contextlib import redirect_stdout
from airflow import settings
from airflow.bin import cli
@@ -35,9 +35,9 @@ class TestCliConnections(unittest.TestCase):
cls.parser = cli.CLIFactory.get_parser()
def test_cli_connections_list(self):
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_list(self.parser.parse_args(['connections', 'list']))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]]
for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1]
@@ -71,7 +71,7 @@ class TestCliConnections(unittest.TestCase):
db.resetdb()
# Add connections:
uri = 'postgresql://airflow:airflow@host:5432/airflow'
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_add(self.parser.parse_args(
['connections', 'add', 'new1',
'--conn_uri=%s' % uri]))
@@ -92,7 +92,7 @@ class TestCliConnections(unittest.TestCase):
connection_command.connections_add(self.parser.parse_args(
['connections', 'add', 'new6',
'--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"]))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
# Check addition stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
@@ -112,11 +112,11 @@ class TestCliConnections(unittest.TestCase):
])
# Attempt to add duplicate
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_add(self.parser.parse_args(
['connections', 'add', 'new1',
'--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
@@ -161,7 +161,7 @@ class TestCliConnections(unittest.TestCase):
None, None, "{'extra': 'yes'}"))
# Delete connections
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_delete(self.parser.parse_args(
['connections', 'delete', 'new1']))
connection_command.connections_delete(self.parser.parse_args(
@@ -174,7 +174,7 @@
['connections', 'delete', 'new5']))
connection_command.connections_delete(self.parser.parse_args(
['connections', 'delete', 'new6']))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
# Check deletion stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
@@ -197,10 +197,10 @@
self.assertTrue(result is None)
# Attempt to delete a non-existing connection
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
connection_command.connections_delete(self.parser.parse_args(
['connections', 'delete', 'fake']))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
# Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
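The `as stdout` binding used throughout these tests works because `redirect_stdout.__enter__` returns the stream it installs, so the bound name refers to the `io.StringIO` buffer itself; a minimal check:
```python
import io
from contextlib import redirect_stdout

buffer = io.StringIO()
with redirect_stdout(buffer) as target:
    print("hello")
assert target is buffer  # __enter__ hands back the installed stream
assert target.getvalue() == "hello\n"
```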

View file

@@ -138,14 +138,12 @@ class TestCliDags(unittest.TestCase):
mock_run.reset_mock()
dag = self.dagbag.get_dag('example_bash_operator')
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with contextlib.redirect_stdout(io.StringIO()) as stdout:
dag_command.dag_backfill(self.parser.parse_args([
'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run',
'-s', DEFAULT_DATE.isoformat()]), dag=dag)
mock_stdout.seek(0, 0)
output = mock_stdout.read()
output = stdout.getvalue()
self.assertIn("Dry run of DAG example_bash_operator on {}\n".format(DEFAULT_DATE.isoformat()), output)
self.assertIn("Task runme_0\n".format(DEFAULT_DATE.isoformat()), output)
@@ -179,8 +177,7 @@
mock_run.reset_mock()
def test_show_dag_print(self):
temp_stdout = io.StringIO()
with contextlib.redirect_stdout(temp_stdout):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_show(self.parser.parse_args([
'dags', 'show', 'example_bash_operator']))
out = temp_stdout.getvalue()
@@ -190,8 +187,7 @@
@mock.patch("airflow.cli.commands.dag_command.render_dag")
def test_show_dag_dave(self, mock_render_dag):
temp_stdout = io.StringIO()
with contextlib.redirect_stdout(temp_stdout):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_show(self.parser.parse_args([
'dags', 'show', 'example_bash_operator', '--save', 'awesome.png']
))
@@ -206,8 +202,7 @@
def test_show_dag_imgcat(self, mock_render_dag, mock_popen):
mock_render_dag.return_value.pipe.return_value = b"DOT_DATA"
mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR")
temp_stdout = io.StringIO()
with contextlib.redirect_stdout(temp_stdout):
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
dag_command.dag_show(self.parser.parse_args([
'dags', 'show', 'example_bash_operator', '--imgcat']
))
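The switch from `mock_stdout.seek(0, 0)` plus `read()` to `stdout.getvalue()` in the hunks above also removes a dependency on the stream position: `read()` returns whatever lies after the current position (nothing, right after writing), while `getvalue()` always returns the whole buffer. A small demonstration:
```python
import io

buf = io.StringIO()
buf.write("Dry run of DAG example_bash_operator\n")
assert buf.read() == ""  # position sits at the end after the write
assert buf.getvalue().startswith("Dry run")  # independent of position
```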

View file

@@ -19,7 +19,7 @@
#
import io
import unittest
from unittest import mock
from contextlib import redirect_stdout
from airflow import models
from airflow.bin import cli
@@ -82,9 +82,9 @@ class TestCliRoles(unittest.TestCase):
self.appbuilder.sm.add_role('FakeTeamA')
self.appbuilder.sm.add_role('FakeTeamB')
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
role_command.roles_list(self.parser.parse_args(['roles', 'list']))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
self.assertIn('FakeTeamA', stdout)
self.assertIn('FakeTeamB', stdout)

View file

@@ -18,9 +18,9 @@
# under the License.
#
import io
import sys
import unittest
from argparse import Namespace
from contextlib import redirect_stdout
from datetime import datetime, timedelta
from unittest import mock
from unittest.mock import MagicMock
@@ -122,16 +122,10 @@ class TestCliTasks(unittest.TestCase):
execution_date=timezone.parse('2018-01-01')
)
saved_stdout = sys.stdout
try:
sys.stdout = out = io.StringIO()
with redirect_stdout(io.StringIO()) as stdout:
task_command.task_test(args)
output = out.getvalue()
# Check that prints, and log messages, are shown
self.assertIn("'example_python_operator__print_the_context__20180101'", output)
finally:
sys.stdout = saved_stdout
# Check that prints, and log messages, are shown
self.assertIn("'example_python_operator__print_the_context__20180101'", stdout.getvalue())
@mock.patch("airflow.cli.commands.task_command.jobs.LocalTaskJob")
def test_run_naive_taskinstance(self, mock_local_job):

View file

@@ -20,7 +20,7 @@ import json
import os
import tempfile
import unittest
from unittest import mock
from contextlib import redirect_stdout
from airflow import models
from airflow.bin import cli
@@ -100,9 +100,9 @@ class TestCliUsers(unittest.TestCase):
'--use_random_password'
])
user_command.users_create(args)
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
user_command.users_list(self.parser.parse_args(['users', 'list']))
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
for i in range(0, 3):
self.assertIn('user{}'.format(i), stdout)

View file

@@ -17,7 +17,7 @@
import io
import unittest
from unittest import mock
from contextlib import redirect_stdout
import airflow.cli.commands.version_command
from airflow.bin import cli
@@ -30,7 +30,6 @@ class TestCliVersion(unittest.TestCase):
cls.parser = cli.CLIFactory.get_parser()
def test_cli_version(self):
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
airflow.cli.commands.version_command.version(self.parser.parse_args(['version']))
stdout = mock_stdout.getvalue()
self.assertIn(version, stdout)
self.assertIn(version, stdout.getvalue())

View file

@@ -23,8 +23,8 @@ import logging
import os
import re
import unittest
from contextlib import redirect_stdout
from tempfile import NamedTemporaryFile
from unittest import mock
from unittest.mock import patch
import pendulum
@@ -908,9 +908,9 @@ class TestDag(unittest.TestCase):
t3 = DummyOperator(task_id="t3")
t1 >> t2 >> t3
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
with redirect_stdout(io.StringIO()) as stdout:
dag.tree_view()
stdout = mock_stdout.getvalue()
stdout = stdout.getvalue()
stdout_lines = stdout.split("\n")
self.assertIn('t1', stdout_lines[0])