[AIRFLOW-1330] Add conn_type argument to CLI when adding connection

Closes #2525 from mrkm4ntr/airflow-1330
This commit is contained in:
Shintaro Murakami 2017-10-19 09:28:09 -07:00 коммит произвёл Chris Riccomini
Родитель b464d23a6d
Коммит 2f107d8a30
2 изменённых файлов: 95 добавлений и 20 удалений

Просмотреть файл

@ -40,6 +40,7 @@ import traceback
import time
import psutil
import re
from urllib.parse import urlunparse
import airflow
from airflow import api
@ -931,11 +932,15 @@ def version(args): # noqa
print(settings.HEADER + " v" + airflow.__version__)
alternative_conn_specs = ['conn_type', 'conn_host',
'conn_login', 'conn_password', 'conn_schema', 'conn_port']
def connections(args):
if args.list:
# Check that no other flags were passed to the command
invalid_args = list()
for arg in ['conn_id', 'conn_uri', 'conn_extra']:
for arg in ['conn_id', 'conn_uri', 'conn_extra'] + alternative_conn_specs:
if getattr(args, arg) is not None:
invalid_args.append(arg)
if invalid_args:
@ -960,7 +965,7 @@ def connections(args):
if args.delete:
# Check that only the `conn_id` arg was passed to the command
invalid_args = list()
for arg in ['conn_uri', 'conn_extra']:
for arg in ['conn_uri', 'conn_extra'] + alternative_conn_specs:
if getattr(args, arg) is not None:
invalid_args.append(arg)
if invalid_args:
@ -1004,16 +1009,32 @@ def connections(args):
if args.add:
# Check that the conn_id and conn_uri args were passed to the command:
missing_args = list()
for arg in ['conn_id', 'conn_uri']:
if getattr(args, arg) is None:
missing_args.append(arg)
invalid_args = list()
if not args.conn_id:
missing_args.append('conn_id')
if args.conn_uri:
for arg in alternative_conn_specs:
if getattr(args, arg) is not None:
invalid_args.append(arg)
elif not args.conn_type:
missing_args.append('conn_uri or conn_type')
if missing_args:
msg = ('\n\tThe following args are required to add a connection:' +
' {missing!r}\n'.format(missing=missing_args))
print(msg)
if invalid_args:
msg = ('\n\tThe following args are not compatible with the ' +
'--add flag and --conn_uri flag: {invalid!r}\n')
msg = msg.format(invalid=invalid_args)
print(msg)
if missing_args or invalid_args:
return
new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
if args.conn_uri:
new_conn = Connection(conn_id=args.conn_id, uri=args.conn_uri)
else:
new_conn = Connection(conn_id=args.conn_id, conn_type=args.conn_type, host=args.conn_host,
login=args.conn_login, password=args.conn_password, schema=args.conn_schema, port=args.conn_port)
if args.conn_extra is not None:
new_conn.set_extra(args.conn_extra)
@ -1024,7 +1045,8 @@ def connections(args):
session.add(new_conn)
session.commit()
msg = '\n\tSuccessfully added `conn_id`={conn_id} : {uri}\n'
msg = msg.format(conn_id=new_conn.conn_id, uri=args.conn_uri)
msg = msg.format(conn_id=new_conn.conn_id, uri=args.conn_uri or urlunparse((args.conn_type, '{login}:{password}@{host}:{port}'.format(
login=args.conn_login or '', password=args.conn_password or '', host=args.conn_host or '', port=args.conn_port or ''), args.conn_schema or '', '', '', '')))
print(msg)
else:
msg = '\n\tA connection with `conn_id`={conn_id} already exists\n'
@ -1420,7 +1442,31 @@ class CLIFactory(object):
type=str),
'conn_uri': Arg(
('--conn_uri',),
help='Connection URI, required to add a connection',
help='Connection URI, required to add a connection without conn_type',
type=str),
'conn_type': Arg(
('--conn_type',),
help='Connection type, required to add a connection without conn_uri',
type=str),
'conn_host': Arg(
('--conn_host',),
help='Connection host, optional when adding a connection',
type=str),
'conn_login': Arg(
('--conn_login',),
help='Connection login, optional when adding a connection',
type=str),
'conn_password': Arg(
('--conn_password',),
help='Connection password, optional when adding a connection',
type=str),
'conn_schema': Arg(
('--conn_schema',),
help='Connection schema, optional when adding a connection',
type=str),
'conn_port': Arg(
('--conn_port',),
help='Connection port, optional when adding a connection',
type=str),
'conn_extra': Arg(
('--conn_extra',),
@ -1558,7 +1604,7 @@ class CLIFactory(object):
'func': connections,
'help': "List/Add/Delete connections",
'args': ('list_connections', 'add_connection', 'delete_connection',
'conn_id', 'conn_uri', 'conn_extra'),
'conn_id', 'conn_uri', 'conn_extra') + tuple(alternative_conn_specs),
},
)
subparsers_dict = {sp['func'].__name__: sp for sp in subparsers}

Просмотреть файл

@ -1136,15 +1136,17 @@ class CliTests(unittest.TestCase):
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--list', '--conn_id=fake',
'--conn_uri=fake-uri']))
['connections', '--list', '--conn_id=fake', '--conn_uri=fake-uri',
'--conn_type=fake-type', '--conn_host=fake_host',
'--conn_login=fake_login', '--conn_password=fake_password',
'--conn_schema=fake_schema', '--conn_port=fake_port', '--conn_extra=fake_extra']))
stdout = mock_stdout.getvalue()
# Check list attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are not compatible with the " +
"--list flag: ['conn_id', 'conn_uri']"),
"--list flag: ['conn_id', 'conn_uri', 'conn_extra', 'conn_type', 'conn_host', 'conn_login', 'conn_password', 'conn_schema', 'conn_port']"),
])
def test_cli_connections_add_delete(self):
@ -1164,6 +1166,14 @@ class CliTests(unittest.TestCase):
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new4',
'--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new5',
'--conn_type=hive_metastore', '--conn_login=airflow',
'--conn_password=airflow', '--conn_host=host',
'--conn_port=9083', '--conn_schema=airflow']))
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new6',
'--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"]))
stdout = mock_stdout.getvalue()
# Check addition stdout
@ -1177,6 +1187,10 @@ class CliTests(unittest.TestCase):
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new4 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new5 : " +
"hive_metastore://airflow:airflow@host:9083/airflow"),
("\tSuccessfully added `conn_id`=new6 : " +
"google_cloud_platform://:@:")
])
# Attempt to add duplicate
@ -1218,7 +1232,7 @@ class CliTests(unittest.TestCase):
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are required to add a connection:" +
" ['conn_uri']"),
" ['conn_uri or conn_type']"),
])
# Prepare to add connections
@ -1229,15 +1243,23 @@ class CliTests(unittest.TestCase):
'new4': "{'extra': 'yes'}"}
# Add connections
for conn_id in ['new1', 'new2', 'new3', 'new4']:
for index in range(1, 6):
conn_id = 'new%s' % index
result = (session
.query(models.Connection)
.filter(models.Connection.conn_id == conn_id)
.first())
result = (result.conn_id, result.conn_type, result.host,
result.port, result.get_extra())
self.assertEqual(result, (conn_id, 'postgres', 'host', 5432,
extra[conn_id]))
if conn_id in ['new1', 'new2', 'new3', 'new4']:
self.assertEqual(result, (conn_id, 'postgres', 'host', 5432,
extra[conn_id]))
elif conn_id == 'new5':
self.assertEqual(result, (conn_id, 'hive_metastore', 'host',
9083, None))
elif conn_id == 'new6':
self.assertEqual(result, (conn_id, 'google_cloud_platform',
None, None, "{'extra': 'yes'}"))
# Delete connections
with mock.patch('sys.stdout',
@ -1250,6 +1272,10 @@ class CliTests(unittest.TestCase):
['connections', '--delete', '--conn_id=new3']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new4']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new5']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new6']))
stdout = mock_stdout.getvalue()
# Check deletion stdout
@ -1258,11 +1284,14 @@ class CliTests(unittest.TestCase):
"\tSuccessfully deleted `conn_id`=new1",
"\tSuccessfully deleted `conn_id`=new2",
"\tSuccessfully deleted `conn_id`=new3",
"\tSuccessfully deleted `conn_id`=new4"
"\tSuccessfully deleted `conn_id`=new4",
"\tSuccessfully deleted `conn_id`=new5",
"\tSuccessfully deleted `conn_id`=new6"
])
# Check deletions
for conn_id in ['new1', 'new2', 'new3', 'new4']:
for index in range(1, 7):
conn_id = 'new%s' % index
result = (session
.query(models.Connection)
.filter(models.Connection.conn_id == conn_id)
@ -1288,14 +1317,14 @@ class CliTests(unittest.TestCase):
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=fake',
'--conn_uri=%s' % uri]))
'--conn_uri=%s' % uri, '--conn_type=fake-type']))
stdout = mock_stdout.getvalue()
# Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are not compatible with the " +
"--delete flag: ['conn_uri']"),
"--delete flag: ['conn_uri', 'conn_type']"),
])
session.close()