[AIRFLOW-6026] Use contextlib to redirect stderr and stdout (#6624)
This commit is contained in:
Parent: 1d8b8cfcbc
Commit: da086661f7
UPDATING.md (+30 lines)

@@ -41,6 +41,36 @@ assists users migrating to a new version.
 
 ## Airflow Master
 
+### Removal of redirect_stdout, redirect_stderr
+
+The functions `redirect_stderr` and `redirect_stdout` from the `airflow.utils.log.logging_mixin` module
+have been deleted because they can easily be replaced by the standard library.
+The standard-library functions are more flexible and cover a wider range of use cases.
+
+The code below
+```python
+import logging
+
+from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
+
+logger = logging.getLogger("custom-logger")
+with redirect_stdout(logger, logging.INFO), redirect_stderr(logger, logging.WARN):
+    print("I love Airflow")
+```
+can be replaced by the following code:
+```python
+from contextlib import redirect_stdout, redirect_stderr
+import logging
+
+from airflow.utils.log.logging_mixin import StreamLogWriter
+
+logger = logging.getLogger("custom-logger")
+
+with redirect_stdout(StreamLogWriter(logger, logging.INFO)), \
+        redirect_stderr(StreamLogWriter(logger, logging.WARN)):
+    print("I love Airflow")
+```
+
 ### Removal of XCom.get_one()
 
 This method is superseded by `XCom.get_many().first()`, which returns the same result.
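For illustration, a hedged sketch of the replacement call; the keyword arguments shown (`execution_date`, `key`, `task_ids`, `dag_ids`) are assumptions about the `XCom.get_many()` signature of this era and may differ in your Airflow version.

```python
# Hedged migration sketch; argument names are assumptions, not confirmed
# by this commit.
from datetime import datetime

from airflow.models import XCom

execution_date = datetime(2019, 1, 1)  # example value

# Previously: XCom.get_one(...) with equivalent filters.
value = XCom.get_many(
    execution_date=execution_date,
    key="return_value",
    task_ids="push_task",
    dag_ids="example_dag",
).first()
```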
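Worth spelling out why the standard-library functions are a drop-in replacement in the migration example above: `contextlib.redirect_stdout` accepts any object with a file-like `write()` method, and `StreamLogWriter` provides one. A minimal sketch, with a hypothetical `MinimalLogWriter` standing in for `StreamLogWriter`:

```python
import logging
from contextlib import redirect_stdout

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("custom-logger")


class MinimalLogWriter:
    """Hypothetical stand-in for StreamLogWriter: forwards writes to a logger."""

    def __init__(self, logger, level):
        self.logger = logger
        self.level = level

    def write(self, message):
        if message.strip():  # print() writes the trailing newline separately
            self.logger.log(self.level, message.rstrip())

    def flush(self):
        pass  # nothing is buffered locally


with redirect_stdout(MinimalLogWriter(logger, logging.INFO)):
    print("routed through the logger")  # logged as INFO on custom-logger
```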
@@ -22,6 +22,7 @@ import json
 import logging
 import os
 import textwrap
+from contextlib import redirect_stderr, redirect_stdout
 
 from airflow import DAG, AirflowException, LoggingMixin, conf, jobs, settings
 from airflow.executors import get_default_executor
@@ -29,7 +30,7 @@ from airflow.models import DagPickle, TaskInstance
 from airflow.ti_deps.dep_context import SCHEDULER_QUEUED_DEPS, DepContext
 from airflow.utils import cli as cli_utils, db
 from airflow.utils.cli import get_dag, get_dags
-from airflow.utils.log.logging_mixin import redirect_stderr, redirect_stdout
+from airflow.utils.log.logging_mixin import StreamLogWriter
 from airflow.utils.net import get_hostname
@@ -131,7 +132,8 @@ def task_run(args, dag=None):
     if args.interactive:
         _run(args, dag, ti)
     else:
-        with redirect_stdout(ti.log, logging.INFO), redirect_stderr(ti.log, logging.WARN):
+        with redirect_stdout(StreamLogWriter(ti.log, logging.INFO)), \
+                redirect_stderr(StreamLogWriter(ti.log, logging.WARN)):
             _run(args, dag, ti)
     logging.shutdown()
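A side note on the backslash continuation above: multiple context managers could not be grouped in parentheses on the Python versions supported at the time. `contextlib.ExitStack` is one way to avoid the continuation; the wrapper below is a hypothetical sketch, not part of this commit:

```python
import logging
from contextlib import ExitStack, redirect_stderr, redirect_stdout

from airflow.utils.log.logging_mixin import StreamLogWriter


def _run_with_redirected_streams(ti, run):
    # Hypothetical helper: enter both redirects without a backslash continuation.
    with ExitStack() as stack:
        stack.enter_context(redirect_stdout(StreamLogWriter(ti.log, logging.INFO)))
        stack.enter_context(redirect_stderr(StreamLogWriter(ti.log, logging.WARN)))
        run()
```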
@@ -26,6 +26,7 @@ import sys
 import threading
 import time
 from collections import defaultdict
+from contextlib import redirect_stderr, redirect_stdout
 from datetime import timedelta
 from time import sleep
 from typing import List, Set
@@ -122,45 +123,40 @@ class DagFileProcessor(AbstractDagFileProcessor, LoggingMixin):
         # This helper runs in the newly created process
         log = logging.getLogger("airflow.processor")
 
-        stdout = StreamLogWriter(log, logging.INFO)
-        stderr = StreamLogWriter(log, logging.WARN)
-
         set_context(log, file_path)
         setproctitle("airflow scheduler - DagFileProcessor {}".format(file_path))
 
         try:
             # redirect stdout/stderr to log
-            sys.stdout = stdout
-            sys.stderr = stderr
+            with redirect_stdout(StreamLogWriter(log, logging.INFO)), \
+                    redirect_stderr(StreamLogWriter(log, logging.WARN)):
 
                 # Re-configure the ORM engine as there are issues with multiple processes
                 settings.configure_orm()
 
                 # Change the thread name to differentiate log lines. This is
                 # really a separate process, but changing the name of the
                 # process doesn't work, so changing the thread name instead.
                 threading.current_thread().name = thread_name
                 start_time = time.time()
 
                 log.info("Started process (PID=%s) to work on %s",
                          os.getpid(), file_path)
                 scheduler_job = SchedulerJob(dag_ids=dag_id_white_list, log=log)
                 result = scheduler_job.process_file(file_path,
                                                     zombies,
                                                     pickle_dags)
                 result_channel.send(result)
                 end_time = time.time()
                 log.info(
                     "Processing %s took %.3f seconds", file_path, end_time - start_time
                 )
         except Exception:
             # Log exceptions through the logging framework.
             log.exception("Got an exception! Propagating...")
             raise
         finally:
             result_channel.close()
-            sys.stdout = sys.__stdout__
-            sys.stderr = sys.__stderr__
             # We re-initialized the ORM within this Process above so we need to
             # tear it down manually here
             settings.dispose_orm()
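The manual resets to `sys.__stdout__` and `sys.__stderr__` in the `finally` block could be dropped because the context managers restore the previous streams on exit, even when an exception propagates. A minimal demonstration of that guarantee:

```python
import io
import sys
from contextlib import redirect_stdout

buffer = io.StringIO()
before = sys.stdout

try:
    with redirect_stdout(buffer):
        print("captured")
        raise RuntimeError("boom")
except RuntimeError:
    pass

assert sys.stdout is before               # restored despite the exception
assert buffer.getvalue() == "captured\n"
```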
@@ -19,7 +19,6 @@
 import logging
 import re
 import sys
-from contextlib import contextmanager
 from logging import Handler, Logger, StreamHandler
 
 # 7-bit C1 ANSI escape sequences
@@ -144,26 +143,6 @@ class RedirectStdHandler(StreamHandler):
         return sys.stdout
 
 
-@contextmanager
-def redirect_stdout(logger, level):
-    writer = StreamLogWriter(logger, level)
-    try:
-        sys.stdout = writer
-        yield
-    finally:
-        sys.stdout = sys.__stdout__
-
-
-@contextmanager
-def redirect_stderr(logger, level):
-    writer = StreamLogWriter(logger, level)
-    try:
-        sys.stderr = writer
-        yield
-    finally:
-        sys.stderr = sys.__stderr__
-
-
 def set_context(logger, value):
     """
     Walks the tree of loggers and tries to set the context for each handler
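One behavioral difference worth noting: the deleted helpers always reset to `sys.__stdout__`/`sys.__stderr__` on exit, whereas `contextlib` restores whatever stream was active when the block was entered, so redirects nest correctly:

```python
import io
from contextlib import redirect_stdout

outer, inner = io.StringIO(), io.StringIO()

with redirect_stdout(outer):
    print("to outer")
    with redirect_stdout(inner):
        print("to inner")
    # The deleted helpers would have reset to the real sys.__stdout__ here;
    # contextlib restores the outer buffer instead.
    print("back to outer")

assert outer.getvalue() == "to outer\nback to outer\n"
assert inner.getvalue() == "to inner\n"
```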
@@ -20,7 +20,7 @@ import re
 import subprocess
 import tempfile
 import unittest
-from unittest import mock
+from contextlib import redirect_stdout
 
 from airflow import settings
 from airflow.bin import cli
@@ -35,9 +35,9 @@ class TestCliConnections(unittest.TestCase):
         cls.parser = cli.CLIFactory.get_parser()
 
     def test_cli_connections_list(self):
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_list(self.parser.parse_args(['connections', 'list']))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
         conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]]
                  for ii, line in enumerate(stdout.split('\n'))
                  if ii % 2 == 1]
@@ -71,7 +71,7 @@ class TestCliConnections(unittest.TestCase):
         db.resetdb()
         # Add connections:
         uri = 'postgresql://airflow:airflow@host:5432/airflow'
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_add(self.parser.parse_args(
                 ['connections', 'add', 'new1',
                  '--conn_uri=%s' % uri]))
@@ -92,7 +92,7 @@ class TestCliConnections(unittest.TestCase):
             connection_command.connections_add(self.parser.parse_args(
                 ['connections', 'add', 'new6',
                  '--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"]))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
 
         # Check addition stdout
         lines = [l for l in stdout.split('\n') if len(l) > 0]
@@ -112,11 +112,11 @@ class TestCliConnections(unittest.TestCase):
             ])
 
         # Attempt to add duplicate
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_add(self.parser.parse_args(
                 ['connections', 'add', 'new1',
                  '--conn_uri=%s' % uri]))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
 
         # Check stdout for addition attempt
         lines = [l for l in stdout.split('\n') if len(l) > 0]
@@ -161,7 +161,7 @@ class TestCliConnections(unittest.TestCase):
             None, None, "{'extra': 'yes'}"))
 
         # Delete connections
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_delete(self.parser.parse_args(
                 ['connections', 'delete', 'new1']))
             connection_command.connections_delete(self.parser.parse_args(
@@ -174,7 +174,7 @@ class TestCliConnections(unittest.TestCase):
                 ['connections', 'delete', 'new5']))
             connection_command.connections_delete(self.parser.parse_args(
                 ['connections', 'delete', 'new6']))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
 
         # Check deletion stdout
         lines = [l for l in stdout.split('\n') if len(l) > 0]
@@ -197,10 +197,10 @@ class TestCliConnections(unittest.TestCase):
         self.assertTrue(result is None)
 
         # Attempt to delete a non-existing connection
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             connection_command.connections_delete(self.parser.parse_args(
                 ['connections', 'delete', 'fake']))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
 
         # Check deletion attempt stdout
         lines = [l for l in stdout.split('\n') if len(l) > 0]
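The tests now repeat the same capture-and-read pattern; a hypothetical helper (not part of this commit) that condenses it:

```python
import io
from contextlib import redirect_stdout


def capture_stdout(func, *args, **kwargs):
    """Run func and return everything it printed to stdout."""
    with redirect_stdout(io.StringIO()) as buffer:
        func(*args, **kwargs)
    return buffer.getvalue()


# Usage sketch:
# output = capture_stdout(connection_command.connections_list,
#                         parser.parse_args(['connections', 'list']))
```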
@@ -138,14 +138,12 @@ class TestCliDags(unittest.TestCase):
         mock_run.reset_mock()
         dag = self.dagbag.get_dag('example_bash_operator')
 
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with contextlib.redirect_stdout(io.StringIO()) as stdout:
             dag_command.dag_backfill(self.parser.parse_args([
                 'dags', 'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run',
                 '-s', DEFAULT_DATE.isoformat()]), dag=dag)
 
-        mock_stdout.seek(0, 0)
-
-        output = mock_stdout.read()
+        output = stdout.getvalue()
         self.assertIn("Dry run of DAG example_bash_operator on {}\n".format(DEFAULT_DATE.isoformat()), output)
         self.assertIn("Task runme_0\n".format(DEFAULT_DATE.isoformat()), output)
 
@@ -179,8 +177,7 @@ class TestCliDags(unittest.TestCase):
         mock_run.reset_mock()
 
     def test_show_dag_print(self):
-        temp_stdout = io.StringIO()
-        with contextlib.redirect_stdout(temp_stdout):
+        with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
             dag_command.dag_show(self.parser.parse_args([
                 'dags', 'show', 'example_bash_operator']))
         out = temp_stdout.getvalue()
@@ -190,8 +187,7 @@ class TestCliDags(unittest.TestCase):
 
     @mock.patch("airflow.cli.commands.dag_command.render_dag")
     def test_show_dag_dave(self, mock_render_dag):
-        temp_stdout = io.StringIO()
-        with contextlib.redirect_stdout(temp_stdout):
+        with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
             dag_command.dag_show(self.parser.parse_args([
                 'dags', 'show', 'example_bash_operator', '--save', 'awesome.png']
             ))
@@ -206,8 +202,7 @@ class TestCliDags(unittest.TestCase):
     def test_show_dag_imgcat(self, mock_render_dag, mock_popen):
         mock_render_dag.return_value.pipe.return_value = b"DOT_DATA"
         mock_popen.return_value.communicate.return_value = (b"OUT", b"ERR")
-        temp_stdout = io.StringIO()
-        with contextlib.redirect_stdout(temp_stdout):
+        with contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
             dag_command.dag_show(self.parser.parse_args([
                 'dags', 'show', 'example_bash_operator', '--imgcat']
             ))
@@ -19,7 +19,7 @@
 #
 import io
 import unittest
-from unittest import mock
+from contextlib import redirect_stdout
 
 from airflow import models
 from airflow.bin import cli
@@ -82,9 +82,9 @@ class TestCliRoles(unittest.TestCase):
         self.appbuilder.sm.add_role('FakeTeamA')
         self.appbuilder.sm.add_role('FakeTeamB')
 
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             role_command.roles_list(self.parser.parse_args(['roles', 'list']))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
 
         self.assertIn('FakeTeamA', stdout)
         self.assertIn('FakeTeamB', stdout)
@@ -18,9 +18,9 @@
 # under the License.
 #
 import io
-import sys
 import unittest
 from argparse import Namespace
+from contextlib import redirect_stdout
 from datetime import datetime, timedelta
 from unittest import mock
 from unittest.mock import MagicMock
@@ -122,16 +122,10 @@ class TestCliTasks(unittest.TestCase):
             execution_date=timezone.parse('2018-01-01')
         )
 
-        saved_stdout = sys.stdout
-        try:
-            sys.stdout = out = io.StringIO()
-            task_command.task_test(args)
-
-            output = out.getvalue()
-            # Check that prints, and log messages, are shown
-            self.assertIn("'example_python_operator__print_the_context__20180101'", output)
-        finally:
-            sys.stdout = saved_stdout
+        with redirect_stdout(io.StringIO()) as stdout:
+            task_command.task_test(args)
+
+        # Check that prints, and log messages, are shown
+        self.assertIn("'example_python_operator__print_the_context__20180101'", stdout.getvalue())
 
     @mock.patch("airflow.cli.commands.task_command.jobs.LocalTaskJob")
     def test_run_naive_taskinstance(self, mock_local_job):
@@ -20,7 +20,7 @@ import json
 import os
 import tempfile
 import unittest
-from unittest import mock
+from contextlib import redirect_stdout
 
 from airflow import models
 from airflow.bin import cli
@@ -100,9 +100,9 @@ class TestCliUsers(unittest.TestCase):
             '--use_random_password'
         ])
         user_command.users_create(args)
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             user_command.users_list(self.parser.parse_args(['users', 'list']))
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
         for i in range(0, 3):
             self.assertIn('user{}'.format(i), stdout)
@@ -17,7 +17,7 @@
 
 import io
 import unittest
-from unittest import mock
+from contextlib import redirect_stdout
 
 import airflow.cli.commands.version_command
 from airflow.bin import cli
@@ -30,7 +30,6 @@ class TestCliVersion(unittest.TestCase):
         cls.parser = cli.CLIFactory.get_parser()
 
     def test_cli_version(self):
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             airflow.cli.commands.version_command.version(self.parser.parse_args(['version']))
-            stdout = mock_stdout.getvalue()
-        self.assertIn(version, stdout)
+        self.assertIn(version, stdout.getvalue())
@@ -23,8 +23,8 @@ import logging
 import os
 import re
 import unittest
+from contextlib import redirect_stdout
 from tempfile import NamedTemporaryFile
-from unittest import mock
 from unittest.mock import patch
 
 import pendulum
@@ -908,9 +908,9 @@ class TestDag(unittest.TestCase):
         t3 = DummyOperator(task_id="t3")
         t1 >> t2 >> t3
 
-        with mock.patch('sys.stdout', new_callable=io.StringIO) as mock_stdout:
+        with redirect_stdout(io.StringIO()) as stdout:
             dag.tree_view()
-            stdout = mock_stdout.getvalue()
+            stdout = stdout.getvalue()
 
         stdout_lines = stdout.split("\n")
         self.assertIn('t1', stdout_lines[0])