[AIRFLOW-3672] Add support for Mongo DB DNS Seedlist Connection Format (#4481)
* [AIRFLOW-3672] Add support for Mongo DB DNS Seedlist Connection Format https://docs.mongodb.com/manual/reference/connection-string/index.html#dns-seedlist-connection-format http://api.mongodb.com/python/current/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient * [AIRFLOW-3672] Add unit test for srv uri * [AIRFLOW-3672] Fix unit test for Mongo srv uri * [AIRFLOW-3672] Construct MongoDB URI when hook init
This commit is contained in:
Родитель
c6efd01264
Коммит
ac464be88e
|
@ -29,7 +29,11 @@ class MongoHook(BaseHook):
|
|||
https://docs.mongodb.com/manual/reference/connection-string/index.html
|
||||
You can specify connection string options in extra field of your connection
|
||||
https://docs.mongodb.com/manual/reference/connection-string/index.html#connection-string-options
|
||||
ex. ``{replicaSet: test, ssl: True, connectTimeoutMS: 30000}``
|
||||
|
||||
If you want use DNS seedlist, set `srv` to True.
|
||||
|
||||
ex.
|
||||
{"srv": true, "replicaSet": "test", "ssl": true, "connectTimeoutMS": 30000}
|
||||
"""
|
||||
conn_type = 'mongo'
|
||||
|
||||
|
@ -38,9 +42,23 @@ class MongoHook(BaseHook):
|
|||
|
||||
self.mongo_conn_id = conn_id
|
||||
self.connection = self.get_connection(conn_id)
|
||||
self.extras = self.connection.extra_dejson
|
||||
self.extras = self.connection.extra_dejson.copy()
|
||||
self.client = None
|
||||
|
||||
srv = self.extras.pop('srv', False)
|
||||
scheme = 'mongodb+srv' if srv else 'mongodb'
|
||||
|
||||
self.uri = '{scheme}://{creds}{host}{port}/{database}'.format(
|
||||
scheme=scheme,
|
||||
creds='{}:{}@'.format(
|
||||
self.connection.login, self.connection.password
|
||||
) if self.connection.login else '',
|
||||
|
||||
host=self.connection.host,
|
||||
port='' if self.connection.port is None else ':{}'.format(self.connection.port),
|
||||
database=self.connection.schema
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
|
@ -55,18 +73,6 @@ class MongoHook(BaseHook):
|
|||
if self.client is not None:
|
||||
return self.client
|
||||
|
||||
conn = self.connection
|
||||
|
||||
uri = 'mongodb://{creds}{host}{port}/{database}'.format(
|
||||
creds='{}:{}@'.format(
|
||||
conn.login, conn.password
|
||||
) if conn.login else '',
|
||||
|
||||
host=conn.host,
|
||||
port='' if conn.port is None else ':{}'.format(conn.port),
|
||||
database=conn.schema
|
||||
)
|
||||
|
||||
# Mongo Connection Options dict that is unpacked when passed to MongoClient
|
||||
options = self.extras
|
||||
|
||||
|
@ -74,7 +80,7 @@ class MongoHook(BaseHook):
|
|||
if options.get('ssl', False):
|
||||
options.update({'ssl_cert_reqs': CERT_NONE})
|
||||
|
||||
self.client = MongoClient(uri, **options)
|
||||
self.client = MongoClient(self.uri, **options)
|
||||
|
||||
return self.client
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -238,7 +238,7 @@ samba = ['pysmbclient>=0.1.3']
|
|||
segment = ['analytics-python>=1.2.9']
|
||||
sendgrid = ['sendgrid>=5.2.0,<6']
|
||||
slack = ['slackclient>=1.0.0']
|
||||
mongo = ['pymongo>=3.6.0']
|
||||
mongo = ['pymongo>=3.6.0', 'dnspython>=1.13.0,<2.0.0']
|
||||
snowflake = ['snowflake-connector-python>=1.5.2',
|
||||
'snowflake-sqlalchemy>=1.1.0']
|
||||
ssh = ['paramiko>=2.1.1', 'pysftp>=0.2.9', 'sshtunnel>=0.1.4,<0.2']
|
||||
|
|
|
@ -25,6 +25,8 @@ except ImportError:
|
|||
|
||||
from airflow import configuration
|
||||
from airflow.contrib.hooks.mongo_hook import MongoHook
|
||||
from airflow.models import Connection
|
||||
from airflow.utils import db
|
||||
|
||||
|
||||
class MongoHookTest(MongoHook):
|
||||
|
@ -44,12 +46,21 @@ class TestMongoHook(unittest.TestCase):
|
|||
configuration.load_test_config()
|
||||
self.hook = MongoHookTest(conn_id='mongo_default', mongo_db='default')
|
||||
self.conn = self.hook.get_conn()
|
||||
db.merge_conn(
|
||||
Connection(
|
||||
conn_id='mongo_default_with_srv', conn_type='mongo',
|
||||
host='mongo', port='27017', extra='{"srv": true}'))
|
||||
|
||||
@unittest.skipIf(mongomock is None, 'mongomock package not present')
|
||||
def test_get_conn(self):
|
||||
self.assertEqual(self.hook.connection.port, 27017)
|
||||
self.assertIsInstance(self.conn, pymongo.MongoClient)
|
||||
|
||||
@unittest.skipIf(mongomock is None, 'mongomock package not present')
|
||||
def test_srv(self):
|
||||
hook = MongoHook(conn_id='mongo_default_with_srv')
|
||||
self.assertTrue(hook.uri.startswith('mongodb+srv://'))
|
||||
|
||||
@unittest.skipIf(mongomock is None, 'mongomock package not present')
|
||||
def test_insert_one(self):
|
||||
collection = mongomock.MongoClient().db.collection
|
||||
|
|
Загрузка…
Ссылка в новой задаче