[AIRFLOW-6026] Use contextlib to redirect stderr and stdout (#6624)
This commit is contained in:
Родитель
1d8b8cfcbc
Коммит
da086661f7
30
UPDATING.md
30
UPDATING.md
|
@ -41,6 +41,36 @@ assists users migrating to a new version.
|
|||
|
||||
## Airflow Master
|
||||
|
||||
### Removal of redirect_stdout, redirect_stderr
|
||||
|
||||
Functions `redirect_stderr` and `redirect_stdout` from the `airflow.utils.log.logging_mixin` module have
|
||||
been deleted because they can be easily replaced by the standard library.
|
||||
The functions from the standard library are more flexible and can be used in a wider range of cases.
|
||||
|
||||
The code below
|
||||
```python
|
||||
import logging
|
||||
|
||||
from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
|
||||
|
||||
logger = logging.getLogger("custom-logger")
|
||||
with redirect_stdout(logger, logging.INFO), redirect_stderr(logger, logging.WARN):
|
||||
print("I love Airflow")
|
||||
```
|
||||
can be replaced by the following code:
|
||||
```python
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
import logging
|
||||
|
||||
from airflow.utils.log.logging_mixin import StreamLogWriter
|
||||
|
||||
logger = logging.getLogger("custom-logger")
|
||||
|
||||
with redirect_stdout(StreamLogWriter(logger, logging.INFO)), \
|
||||
redirect_stderr(StreamLogWriter(logger, logging.WARN)):
|
||||
    print("I love Airflow")
|
||||
```
|
||||
|
||||
### Removal of XCom.get_one()
|
||||
|
||||
This method has been superseded by `XCom.get_many().first()`, which will return the same result.
|
||||
|
|
|
@ -22,6 +22,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import textwrap
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
|
||||
from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings
|
||||
from airflow.executors import get_default_executor
|
||||
|
@ -29,7 +30,7 @@ from airflow.models import DagPickle, TaskInstance
|
|||
from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext
|
||||
from airflow.utils import cli as cli_utils, db
|
||||
from airflow.utils.cli import get_dag, get_dags
|
||||
from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
|
||||
from airflow.utils.log.logging_mixin import StreamLogWriter
|
||||
from airflow.utils.net import get_hostname
|
||||
|
||||
|
||||
|
@ -131,7 +132,8 @@ def task_run(args, dag=None):
|
|||
if args.interactive:
|
||||
_run(args, dag, ti)
|
||||
else:
|
||||
with redirect_stdout(ti.log, logging.INFO), redirect_stderr(ti.log, logging.WARN):
|
||||
with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \
|
||||
redirect_stderr(StreamLogWriter(ti.log, logging.WARN)):
|
||||
_run(args, dag, ti)
|
||||
logging.shutdown()
|
||||
|
||||
|
|
|
@ -26,6 +26,7 @@ import sys
|
|||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
from datetime import timedelta
|
||||
from time import sleep
|
||||
from typing import List, Set
|
||||
|
@ -122,45 +123,40 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin):
|
|||
# This helper runs in the newly created process
|
||||
log = logging.getLogger("airflow.processor")
|
||||
|
||||
stdout = StreamLogWriter(log, logging.INFO)
|
||||
stderr = StreamLogWriter(log, logging.WARN)
|
||||
|
||||
set_context(log, file_path)
|
||||
setproctitle("airflow scheduler - DagFileProcessor {}".format(file_path))
|
||||
|
||||
try:
|
||||
# redirect stdout/stderr to log
|
||||
sys.stdout = stdout
|
||||
sys.stderr = stderr
|
||||
with redirect_stdout(StreamLogWriter(log, logging.INFO)),\
|
||||
redirect_stderr(StreamLogWriter(log, logging.WARN)):
|
||||
|
||||
# Re-configure the ORM engine as there are issues with multiple processes
|
||||
settings.configure_orm()
|
||||
# Re-configure the ORM engine as there are issues with multiple processes
|
||||
settings.configure_orm()
|
||||
|
||||
# Change the thread name to differentiate log lines. This is
|
||||
# really a separate process, but changing the name of the
|
||||
# process doesn't work, so changing the thread name instead.
|
||||
threading.current_thread().name = thread_name
|
||||
start_time = time.time()
|
||||
# Change the thread name to differentiate log lines. This is
|
||||
# really a separate process, but changing the name of the
|
||||
# process doesn't work, so changing the thread name instead.
|
||||
threading.current_thread().name = thread_name
|
||||
start_time = time.time()
|
||||
|
||||
log.info("Started process (PID=%s) to work on %s",
|
||||
os.getpid(), file_path)
|
||||
scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log)
|
||||
result = scheduler_job.process_file(file_path,
|
||||
zombies,
|
||||
pickle_dags)
|
||||
result_channel.send(result)
|
||||
end_time = time.time()
|
||||
log.info(
|
||||
"Processing %s took %.3f seconds", file_path, end_time - start_time
|
||||
)
|
||||
log.info("Started process (PID=%s) to work on %s",
|
||||
os.getpid(), file_path)
|
||||
scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log)
|
||||
result = scheduler_job.process_file(file_path,
|
||||
zombies,
|
||||
pickle_dags)
|
||||
result_channel.send(result)
|
||||
end_time = time.time()
|
||||
log.info(
|
||||
"Processing %s took %.3f seconds", file_path, end_time - start_time
|
||||
)
|
||||
except Exception:
|
||||
# Log exceptions through the logging framework.
|
||||
log.exception("Got an exception! Propagating...")
|
||||
raise
|
||||
finally:
|
||||
result_channel.close()
|
||||
sys.stdout = sys.__stdout__
|
||||
sys.stderr = sys.__stderr__
|
||||
# We re-initialized the ORM within this Process above so we need to
|
||||
# tear it down manually here
|
||||
settings.dispose_orm()
|
||||
|
|
|
@ -19,7 +19,6 @@
|
|||
import logging
|
||||
import re
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from logging import Handler, Logger, StreamHandler
|
||||
|
||||
# 7-bit C1 ANSI escape sequences
|
||||
|
@ -144,26 +143,6 @@ class RedirectStdHandler(StreamHandler):
|
|||
return sys.stdout
|
||||
|
||||
|
||||
@contextmanager
|
||||
def redirect_stdout(logger, level):
|
||||
writer = StreamLogWriter(logger, level)
|
||||
try:
|
||||
sys.stdout = writer
|
||||
yield
|
||||
finally:
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
|
||||
@contextmanager
|
||||
def redirect_stderr(logger, level):
|
||||
writer = StreamLogWriter(logger, level)
|
||||
try:
|
||||
sys.stderr = writer
|
||||
yield
|
||||
finally:
|
||||
sys.stderr = sys.__stderr__
|
||||
|
||||
|
||||
def set_context(logger, value):
|
||||
"""
|
||||
Walks the tree of loggers and tries to set the context for each handler
|
||||
|
|
|
@ -20,7 +20,7 @@ import re
|
|||
import subprocess
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
from airflow import settings
|
||||
from airflow.bin import cli
|
||||
|
@ -35,9 +35,9 @@ class TestCliConnections(unittest.TestCase):
|
|||
cls.parser = cli.CLIFactory.get_parser()
|
||||
|
||||
def test_cli_connections_list(self):
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
connection_command.connections_list(self.parser.parse_args(['connections', 'list']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]]
|
||||
for ii, line in enumerate(stdout.split('\n'))
|
||||
if ii % 2 == 1]
|
||||
|
@ -71,7 +71,7 @@ class TestCliConnections(unittest.TestCase):
|
|||
db.resetdb()
|
||||
# Add connections:
|
||||
uri = 'postgresql://airflow:airflow@host:5432/airflow'
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
connection_command.connections_add(self.parser.parse_args(
|
||||
['connections', 'add', 'new1',
|
||||
'--conn_uri=%s' % uri]))
|
||||
|
@ -92,7 +92,7 @@ class TestCliConnections(unittest.TestCase):
|
|||
connection_command.connections_add(self.parser.parse_args(
|
||||
['connections', 'add', 'new6',
|
||||
'--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"]))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
|
||||
# Check addition stdout
|
||||
lines = [l for l in stdout.split('\n') if len(l) > 0]
|
||||
|
@ -112,11 +112,11 @@ class TestCliConnections(unittest.TestCase):
|
|||
])
|
||||
|
||||
# Attempt to add duplicate
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
connection_command.connections_add(self.parser.parse_args(
|
||||
['connections', 'add', 'new1',
|
||||
'--conn_uri=%s' % uri]))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
|
||||
# Check stdout for addition attempt
|
||||
lines = [l for l in stdout.split('\n') if len(l) > 0]
|
||||
|
@ -161,7 +161,7 @@ class TestCliConnections(unittest.TestCase):
|
|||
None, None, "{'extra': 'yes'}"))
|
||||
|
||||
# Delete connections
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
connection_command.connections_delete(self.parser.parse_args(
|
||||
['connections', 'delete', 'new1']))
|
||||
connection_command.connections_delete(self.parser.parse_args(
|
||||
|
@ -174,7 +174,7 @@ class TestCliConnections(unittest.TestCase):
|
|||
['connections', 'delete', 'new5']))
|
||||
connection_command.connections_delete(self.parser.parse_args(
|
||||
['connections', 'delete', 'new6']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
|
||||
# Check deletion stdout
|
||||
lines = [l for l in stdout.split('\n') if len(l) > 0]
|
||||
|
@ -197,10 +197,10 @@ class TestCliConnections(unittest.TestCase):
|
|||
self.assertTrue(result is None)
|
||||
|
||||
# Attempt to delete a non-existing connection
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
connection_command.connections_delete(self.parser.parse_args(
|
||||
['connections', 'delete', 'fake']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
|
||||
# Check deletion attempt stdout
|
||||
lines = [l for l in stdout.split('\n') if len(l) > 0]
|
||||
|
|
|
@ -138,14 +138,12 @@ class TestCliDags(unittest.TestCase):
|
|||
mock_run.reset_mock()
|
||||
dag = self.dagbag.get_dag('example_bash_operator')
|
||||
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with contextlib.redirect_stdout(io.StringIO()) as stdout:
|
||||
dag_command.dag_backfill(self.parser.parse_args([
|
||||
'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run',
|
||||
'-s', DEFAULT_DATE.isoformat()]), dag=dag)
|
||||
|
||||
mock_stdout.seek(0, 0)
|
||||
|
||||
output = mock_stdout.read()
|
||||
output = stdout.getvalue()
|
||||
self.assertIn("Dry run of DAG example_bash_operator on {}\n".format(DEFAULT_DATE.isoformat()), output)
|
||||
self.assertIn("Task runme_0\n".format(DEFAULT_DATE.isoformat()), output)
|
||||
|
||||
|
@ -179,8 +177,7 @@ class TestCliDags(unittest.TestCase):
|
|||
mock_run.reset_mock()
|
||||
|
||||
def test_show_dag_print(self):
|
||||
temp_stdout = io.StringIO()
|
||||
with contextlib.redirect_stdout(temp_stdout):
|
||||
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
|
||||
dag_command.dag_show(self.parser.parse_args([
|
||||
'dags', 'show', 'example_bash_operator']))
|
||||
out = temp_stdout.getvalue()
|
||||
|
@ -190,8 +187,7 @@ class TestCliDags(unittest.TestCase):
|
|||
|
||||
@mock.patch("airflow.cli.commands.dag_command.render_dag")
|
||||
def test_show_dag_dave(self, mock_render_dag):
|
||||
temp_stdout = io.StringIO()
|
||||
with contextlib.redirect_stdout(temp_stdout):
|
||||
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
|
||||
dag_command.dag_show(self.parser.parse_args([
|
||||
'dags', 'show', 'example_bash_operator', '--save', 'awesome.png']
|
||||
))
|
||||
|
@ -206,8 +202,7 @@ class TestCliDags(unittest.TestCase):
|
|||
def test_show_dag_imgcat(self, mock_render_dag, mock_popen):
|
||||
mock_render_dag.return_value.pipe.return_value = b"DOT_DATA"
|
||||
mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR")
|
||||
temp_stdout = io.StringIO()
|
||||
with contextlib.redirect_stdout(temp_stdout):
|
||||
with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
|
||||
dag_command.dag_show(self.parser.parse_args([
|
||||
'dags', 'show', 'example_bash_operator', '--imgcat']
|
||||
))
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#
|
||||
import io
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
from airflow import models
|
||||
from airflow.bin import cli
|
||||
|
@ -82,9 +82,9 @@ class TestCliRoles(unittest.TestCase):
|
|||
self.appbuilder.sm.add_role('FakeTeamA')
|
||||
self.appbuilder.sm.add_role('FakeTeamB')
|
||||
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
role_command.roles_list(self.parser.parse_args(['roles', 'list']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
|
||||
self.assertIn('FakeTeamA', stdout)
|
||||
self.assertIn('FakeTeamB', stdout)
|
||||
|
|
|
@ -18,9 +18,9 @@
|
|||
# under the License.
|
||||
#
|
||||
import io
|
||||
import sys
|
||||
import unittest
|
||||
from argparse import Namespace
|
||||
from contextlib import redirect_stdout
|
||||
from datetime import datetime, timedelta
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
@ -122,16 +122,10 @@ class TestCliTasks(unittest.TestCase):
|
|||
execution_date=timezone.parse('2018-01-01')
|
||||
)
|
||||
|
||||
saved_stdout = sys.stdout
|
||||
try:
|
||||
sys.stdout = out = io.StringIO()
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
task_command.task_test(args)
|
||||
|
||||
output = out.getvalue()
|
||||
# Check that prints, and log messages, are shown
|
||||
self.assertIn("'example_python_operator__print_the_context__20180101'", output)
|
||||
finally:
|
||||
sys.stdout = saved_stdout
|
||||
# Check that prints, and log messages, are shown
|
||||
self.assertIn("'example_python_operator__print_the_context__20180101'", stdout.getvalue())
|
||||
|
||||
@mock.patch("airflow.cli.commands.task_command.jobs.LocalTaskJob")
|
||||
def test_run_naive_taskinstance(self, mock_local_job):
|
||||
|
|
|
@ -20,7 +20,7 @@ import json
|
|||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
from airflow import models
|
||||
from airflow.bin import cli
|
||||
|
@ -100,9 +100,9 @@ class TestCliUsers(unittest.TestCase):
|
|||
'--use_random_password'
|
||||
])
|
||||
user_command.users_create(args)
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
user_command.users_list(self.parser.parse_args(['users', 'list']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
for i in range(0, 3):
|
||||
self.assertIn('user{}'.format(i), stdout)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
|
||||
import io
|
||||
import unittest
|
||||
from unittest import mock
|
||||
from contextlib import redirect_stdout
|
||||
|
||||
import airflow.cli.commands.version_command
|
||||
from airflow.bin import cli
|
||||
|
@ -30,7 +30,6 @@ class TestCliVersion(unittest.TestCase):
|
|||
cls.parser = cli.CLIFactory.get_parser()
|
||||
|
||||
def test_cli_version(self):
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
airflow.cli.commands.version_command.version(self.parser.parse_args(['version']))
|
||||
stdout = mock_stdout.getvalue()
|
||||
self.assertIn(version, stdout)
|
||||
self.assertIn(version, stdout.getvalue())
|
||||
|
|
|
@ -23,8 +23,8 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import unittest
|
||||
from contextlib import redirect_stdout
|
||||
from tempfile import NamedTemporaryFile
|
||||
from unittest import mock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pendulum
|
||||
|
@ -908,9 +908,9 @@ class TestDag(unittest.TestCase):
|
|||
t3 = DummyOperator(task_id="t3")
|
||||
t1 >> t2 >> t3
|
||||
|
||||
with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
|
||||
with redirect_stdout(io.StringIO()) as stdout:
|
||||
dag.tree_view()
|
||||
stdout = mock_stdout.getvalue()
|
||||
stdout = stdout.getvalue()
|
||||
|
||||
stdout_lines = stdout.split("\n")
|
||||
self.assertIn('t1', stdout_lines[0])
|
||||
|
|
Загрузка…
Ссылка в новой задаче