[AIRFLOW-3272] Add base grpc hook (#4101)
* [AIRFLOW-3272] add base grpc hook * [AIRFLOW-3272] fix based on comments and add more docs * [AIRFLOW-3272] add extra fields to www_rabc view in connection model * [AIRFLOW-3272] change url for grpc, fix some bugs * [AIRFLOW-3272] Add mcck for grpc * [AIRFLOW-3272] add unit tests for grpc hook * [AIRFLOW-3272] add gRPC connection howto doc
This commit is contained in:
Родитель
ddec6bbeb3
Коммит
8d5d46022b
|
@ -0,0 +1,123 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
from google import auth as google_auth
|
||||||
|
from google.auth import jwt as google_auth_jwt
|
||||||
|
from google.auth.transport import grpc as google_auth_transport_grpc
|
||||||
|
from google.auth.transport import requests as google_auth_transport_requests
|
||||||
|
|
||||||
|
from airflow.hooks.base_hook import BaseHook
|
||||||
|
from airflow.exceptions import AirflowConfigException
|
||||||
|
|
||||||
|
|
||||||
|
class GrpcHook(BaseHook):
|
||||||
|
"""
|
||||||
|
General interaction with gRPC servers.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, grpc_conn_id, interceptors=None, custom_connection_func=None):
|
||||||
|
"""
|
||||||
|
:param grpc_conn_id: The connection ID to use when fetching connection info.
|
||||||
|
:type grpc_conn_id: str
|
||||||
|
:param interceptors: a list of gRPC interceptor objects which would be applied
|
||||||
|
to the connected gRPC channel. None by default.
|
||||||
|
:type interceptors: a list of gRPC interceptors based on or extends the four
|
||||||
|
official gRPC interceptors, eg, UnaryUnaryClientInterceptor,
|
||||||
|
UnaryStreamClientInterceptor, StreamUnaryClientInterceptor,
|
||||||
|
StreamStreamClientInterceptor.
|
||||||
|
::param custom_connection_func: The customized connection function to return gRPC channel.
|
||||||
|
:type custom_connection_func: python callable objects that accept the connection as
|
||||||
|
its only arg. Could be partial or lambda.
|
||||||
|
"""
|
||||||
|
self.grpc_conn_id = grpc_conn_id
|
||||||
|
self.conn = self.get_connection(self.grpc_conn_id)
|
||||||
|
self.extras = self.conn.extra_dejson
|
||||||
|
self.interceptors = interceptors if interceptors else []
|
||||||
|
self.custom_connection_func = custom_connection_func
|
||||||
|
|
||||||
|
def get_conn(self):
|
||||||
|
base_url = self.conn.host
|
||||||
|
|
||||||
|
if self.conn.port:
|
||||||
|
base_url = base_url + ":" + str(self.conn.port)
|
||||||
|
|
||||||
|
auth_type = self._get_field("auth_type")
|
||||||
|
|
||||||
|
if auth_type == "NO_AUTH":
|
||||||
|
channel = grpc.insecure_channel(base_url)
|
||||||
|
elif auth_type == "SSL" or auth_type == "TLS":
|
||||||
|
credential_file_name = self._get_field("credential_pem_file")
|
||||||
|
creds = grpc.ssl_channel_credentials(open(credential_file_name).read())
|
||||||
|
channel = grpc.secure_channel(base_url, creds)
|
||||||
|
elif auth_type == "JWT_GOOGLE":
|
||||||
|
credentials, _ = google_auth.default()
|
||||||
|
jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials(
|
||||||
|
credentials)
|
||||||
|
channel = google_auth_transport_grpc.secure_authorized_channel(
|
||||||
|
jwt_creds, None, base_url)
|
||||||
|
elif auth_type == "OATH_GOOGLE":
|
||||||
|
scopes = self._get_field("scopes").split(",")
|
||||||
|
credentials, _ = google_auth.default(scopes=scopes)
|
||||||
|
request = google_auth_transport_requests.Request()
|
||||||
|
channel = google_auth_transport_grpc.secure_authorized_channel(
|
||||||
|
credentials, request, base_url)
|
||||||
|
elif auth_type == "CUSTOM":
|
||||||
|
if not self.custom_connection_func:
|
||||||
|
raise AirflowConfigException(
|
||||||
|
"Customized connection function not set, not able to establish a channel")
|
||||||
|
channel = self.custom_connection_func(self.conn)
|
||||||
|
else:
|
||||||
|
raise AirflowConfigException(
|
||||||
|
"auth_type not supported or not provided, channel cannot be established,\
|
||||||
|
given value: %s" % str(auth_type))
|
||||||
|
|
||||||
|
if self.interceptors:
|
||||||
|
for interceptor in self.interceptors:
|
||||||
|
channel = grpc.intercept_channel(channel,
|
||||||
|
interceptor)
|
||||||
|
|
||||||
|
return channel
|
||||||
|
|
||||||
|
def run(self, stub_class, call_func, streaming=False, data={}):
|
||||||
|
with self.get_conn() as channel:
|
||||||
|
stub = stub_class(channel)
|
||||||
|
try:
|
||||||
|
rpc_func = getattr(stub, call_func)
|
||||||
|
response = rpc_func(**data)
|
||||||
|
if not streaming:
|
||||||
|
yield response
|
||||||
|
else:
|
||||||
|
for single_response in response:
|
||||||
|
yield single_response
|
||||||
|
except grpc.RpcError as ex:
|
||||||
|
self.log.exception(
|
||||||
|
"Error occured when calling the grpc service: {0}, method: {1} \
|
||||||
|
status code: {2}, error details: {3}"
|
||||||
|
.format(stub.__class__.__name__, call_func, ex.code(), ex.details()))
|
||||||
|
raise ex
|
||||||
|
|
||||||
|
def _get_field(self, field_name, default=None):
|
||||||
|
"""
|
||||||
|
Fetches a field from extras, and returns it. This is some Airflow
|
||||||
|
magic. The grpc hook type adds custom UI elements
|
||||||
|
to the hook page, which allow admins to specify scopes, credential pem files, etc.
|
||||||
|
They get formatted as shown below.
|
||||||
|
"""
|
||||||
|
full_field_name = 'extra__grpc__{}'.format(field_name)
|
||||||
|
if full_field_name in self.extras:
|
||||||
|
return self.extras[full_field_name]
|
||||||
|
else:
|
||||||
|
return default
|
|
@ -90,6 +90,7 @@ class Connection(Base, LoggingMixin):
|
||||||
('qubole', 'Qubole'),
|
('qubole', 'Qubole'),
|
||||||
('mongo', 'MongoDB'),
|
('mongo', 'MongoDB'),
|
||||||
('gcpcloudsql', 'Google Cloud SQL'),
|
('gcpcloudsql', 'Google Cloud SQL'),
|
||||||
|
('grpc', 'GRPC Connection'),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -247,6 +248,9 @@ class Connection(Base, LoggingMixin):
|
||||||
elif self.conn_type == 'gcpcloudsql':
|
elif self.conn_type == 'gcpcloudsql':
|
||||||
from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
|
from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
|
||||||
return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
|
return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
|
||||||
|
elif self.conn_type == 'grpc':
|
||||||
|
from airflow.contrib.hooks.grpc_hook import GrpcHook
|
||||||
|
return GrpcHook(grpc_conn_id=self.conn_id)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.conn_id
|
return self.conn_id
|
||||||
|
|
|
@ -1928,7 +1928,10 @@ class ConnectionModelView(AirflowModelView):
|
||||||
'extra__google_cloud_platform__project',
|
'extra__google_cloud_platform__project',
|
||||||
'extra__google_cloud_platform__key_path',
|
'extra__google_cloud_platform__key_path',
|
||||||
'extra__google_cloud_platform__keyfile_dict',
|
'extra__google_cloud_platform__keyfile_dict',
|
||||||
'extra__google_cloud_platform__scope']
|
'extra__google_cloud_platform__scope',
|
||||||
|
'extra__grpc__auth_type',
|
||||||
|
'extra__grpc__credential_pem_file',
|
||||||
|
'extra__grpc__scopes']
|
||||||
list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted',
|
list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted',
|
||||||
'is_extra_encrypted']
|
'is_extra_encrypted']
|
||||||
add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema',
|
add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema',
|
||||||
|
@ -1949,7 +1952,7 @@ class ConnectionModelView(AirflowModelView):
|
||||||
|
|
||||||
def process_form(self, form, is_created):
|
def process_form(self, form, is_created):
|
||||||
formdata = form.data
|
formdata = form.data
|
||||||
if formdata['conn_type'] in ['jdbc', 'google_cloud_platform']:
|
if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc']:
|
||||||
extra = {
|
extra = {
|
||||||
key: formdata[key]
|
key: formdata[key]
|
||||||
for key in self.extra_fields if key in formdata}
|
for key in self.extra_fields if key in formdata}
|
||||||
|
|
|
@ -510,6 +510,7 @@ Community contributed hooks
|
||||||
.. autoclass:: airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook
|
.. autoclass:: airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook
|
||||||
.. autoclass:: airflow.contrib.hooks.gcp_kms_hook.GoogleCloudKMSHook
|
.. autoclass:: airflow.contrib.hooks.gcp_kms_hook.GoogleCloudKMSHook
|
||||||
.. autoclass:: airflow.contrib.hooks.gcs_hook.GoogleCloudStorageHook
|
.. autoclass:: airflow.contrib.hooks.gcs_hook.GoogleCloudStorageHook
|
||||||
|
.. autoclass:: airflow.contrib.hooks.grpc_hook.GrpcHook
|
||||||
.. autoclass:: airflow.contrib.hooks.imap_hook.ImapHook
|
.. autoclass:: airflow.contrib.hooks.imap_hook.ImapHook
|
||||||
.. autoclass:: airflow.contrib.hooks.jenkins_hook.JenkinsHook
|
.. autoclass:: airflow.contrib.hooks.jenkins_hook.JenkinsHook
|
||||||
.. autoclass:: airflow.contrib.hooks.jira_hook.JiraHook
|
.. autoclass:: airflow.contrib.hooks.jira_hook.JiraHook
|
||||||
|
|
|
@ -58,6 +58,7 @@ autodoc_mock_imports = [
|
||||||
'google',
|
'google',
|
||||||
'google_auth_httplib2',
|
'google_auth_httplib2',
|
||||||
'googleapiclient',
|
'googleapiclient',
|
||||||
|
'grpc',
|
||||||
'hdfs',
|
'hdfs',
|
||||||
'httplib2',
|
'httplib2',
|
||||||
'jaydebeapi',
|
'jaydebeapi',
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
.. Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
or more contributor license agreements. See the NOTICE file
|
||||||
|
distributed with this work for additional information
|
||||||
|
regarding copyright ownership. The ASF licenses this file
|
||||||
|
to you under the Apache License, Version 2.0 (the
|
||||||
|
"License"); you may not use this file except in compliance
|
||||||
|
with the License. You may obtain a copy of the License at
|
||||||
|
|
||||||
|
.. http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
.. Unless required by applicable law or agreed to in writing,
|
||||||
|
software distributed under the License is distributed on an
|
||||||
|
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
KIND, either express or implied. See the License for the
|
||||||
|
specific language governing permissions and limitations
|
||||||
|
under the License.
|
||||||
|
|
||||||
|
gRPC
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
The gRPC connection type enables integrated connection to a gRPC service
|
||||||
|
|
||||||
|
Authenticating to gRPC
|
||||||
|
'''''''''''''''''''''''
|
||||||
|
|
||||||
|
There are several ways to connect to gRPC service using Airflow.
|
||||||
|
|
||||||
|
1. Using `NO_AUTH` mode, simply setup an insecure channel of connection.
|
||||||
|
2. Using `SSL` or `TLS` mode, supply a credential pem file for the connection id,
|
||||||
|
this will setup SSL or TLS secured connection with gRPC service.
|
||||||
|
3. Using `JWT_GOOGLE` mode. It is using google auth default credentials by default,
|
||||||
|
further use case of getting credentials from service account can be add later on.
|
||||||
|
4. Using `OATH_GOOGLE` mode. Scopes are required in the extra field, can be setup in the UI.
|
||||||
|
It is using google auth default credentials by default,
|
||||||
|
further use case of getting credentials from service account can be add later on.
|
||||||
|
5. Using `CUSTOM` mode. For this type of connection, you can pass in a connection
|
||||||
|
function takes in the connection object and return a gRPC channel and supply whatever
|
||||||
|
authentication type you want.
|
||||||
|
|
||||||
|
Default Connection IDs
|
||||||
|
''''''''''''''''''''''
|
||||||
|
|
||||||
|
The following connection IDs are used by default.
|
||||||
|
|
||||||
|
``grpc_default``
|
||||||
|
Used by the :class:`~airflow.contrib.hooks.grpc_hook.GrpcHook`
|
||||||
|
hook.
|
||||||
|
|
||||||
|
Configuring the Connection
|
||||||
|
''''''''''''''''''''''''''
|
||||||
|
|
||||||
|
Host
|
||||||
|
The host url of the gRPC server
|
||||||
|
|
||||||
|
Port (Optional)
|
||||||
|
The port to connect to on gRPC server
|
||||||
|
|
||||||
|
Auth Type
|
||||||
|
Authentication type of the gRPC connection.
|
||||||
|
`NO_AUTH` by default, possible values are
|
||||||
|
`NO_AUTH`, `SSL`, `TLS`, `JWT_GOOGLE`,
|
||||||
|
`OATH_GOOGLE`, `CUSTOM`
|
||||||
|
|
||||||
|
Credential Pem File (Optional)
|
||||||
|
Pem file that contains credentials for
|
||||||
|
`SSL` and `TLS` type auth
|
||||||
|
Not required for other types.
|
||||||
|
|
||||||
|
Scopes (comma separated) (Optional)
|
||||||
|
A list of comma-separated `Google Cloud scopes
|
||||||
|
<https://developers.google.com/identity/protocols/googlescopes>`_ to
|
||||||
|
authenticate with.
|
||||||
|
Only for `OATH_GOOGLE` type connection
|
4
setup.py
4
setup.py
|
@ -186,6 +186,7 @@ gcp = [
|
||||||
'pandas-gbq'
|
'pandas-gbq'
|
||||||
]
|
]
|
||||||
github_enterprise = ['Flask-OAuthlib>=0.9.1']
|
github_enterprise = ['Flask-OAuthlib>=0.9.1']
|
||||||
|
grpc = ['grpcio>=1.15.0']
|
||||||
google_auth = ['Flask-OAuthlib>=0.9.1']
|
google_auth = ['Flask-OAuthlib>=0.9.1']
|
||||||
hdfs = ['snakebite>=2.7.8']
|
hdfs = ['snakebite>=2.7.8']
|
||||||
hive = [
|
hive = [
|
||||||
|
@ -260,7 +261,7 @@ if not PY3:
|
||||||
devel_minreq = devel + kubernetes + mysql + doc + password + cgroups
|
devel_minreq = devel + kubernetes + mysql + doc + password + cgroups
|
||||||
devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
|
devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
|
||||||
devel_all = (sendgrid + devel + all_dbs + doc + samba + slack + crypto + oracle +
|
devel_all = (sendgrid + devel + all_dbs + doc + samba + slack + crypto + oracle +
|
||||||
docker + ssh + kubernetes + celery + redis + gcp +
|
docker + ssh + kubernetes + celery + redis + gcp + grpc +
|
||||||
datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
|
datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
|
||||||
druid + pinot + segment + snowflake + elasticsearch +
|
druid + pinot + segment + snowflake + elasticsearch +
|
||||||
atlas + azure + aws)
|
atlas + azure + aws)
|
||||||
|
@ -355,6 +356,7 @@ def do_setup():
|
||||||
'gcp_api': gcp, # TODO: remove this in Airflow 2.1
|
'gcp_api': gcp, # TODO: remove this in Airflow 2.1
|
||||||
'github_enterprise': github_enterprise,
|
'github_enterprise': github_enterprise,
|
||||||
'google_auth': google_auth,
|
'google_auth': google_auth,
|
||||||
|
'grpc': grpc,
|
||||||
'hdfs': hdfs,
|
'hdfs': hdfs,
|
||||||
'hive': hive,
|
'hive': hive,
|
||||||
'jdbc': jdbc,
|
'jdbc': jdbc,
|
||||||
|
|
|
@ -0,0 +1,314 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import unittest
|
||||||
|
try:
|
||||||
|
from StringIO import StringIO
|
||||||
|
except ImportError:
|
||||||
|
from io import StringIO
|
||||||
|
|
||||||
|
from airflow import configuration
|
||||||
|
from airflow.exceptions import AirflowConfigException
|
||||||
|
from airflow.contrib.hooks.grpc_hook import GrpcHook
|
||||||
|
from airflow.models.connection import Connection
|
||||||
|
|
||||||
|
try:
|
||||||
|
from unittest import mock
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
import mock
|
||||||
|
except ImportError:
|
||||||
|
mock = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_airflow_connection(auth_type="NO_AUTH", credential_pem_file=None, scopes=None):
|
||||||
|
extra = \
|
||||||
|
'{{"extra__grpc__auth_type": "{auth_type}",' \
|
||||||
|
'"extra__grpc__credential_pem_file": "{credential_pem_file}",' \
|
||||||
|
'"extra__grpc__scopes": "{scopes}"}}' \
|
||||||
|
.format(auth_type=auth_type,
|
||||||
|
credential_pem_file=credential_pem_file,
|
||||||
|
scopes=scopes)
|
||||||
|
|
||||||
|
return Connection(
|
||||||
|
conn_id='grpc_default',
|
||||||
|
conn_type='grpc',
|
||||||
|
host='test:8080',
|
||||||
|
extra=extra
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_airflow_connection_with_port():
|
||||||
|
return Connection(
|
||||||
|
conn_id='grpc_default',
|
||||||
|
conn_type='grpc',
|
||||||
|
host='test.com',
|
||||||
|
port=1234,
|
||||||
|
extra='{"extra__grpc__auth_type": "NO_AUTH"}'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class StubClass(object):
|
||||||
|
def __init__(self, channel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def single_call(self, data):
|
||||||
|
return data
|
||||||
|
|
||||||
|
def stream_call(self, data):
|
||||||
|
return ["streaming", "call"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestGrpcHook(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
configuration.load_test_config()
|
||||||
|
self.channel_mock = mock.patch('grpc.Channel').start()
|
||||||
|
|
||||||
|
def custom_conn_func(self, connection):
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
return mocked_channel
|
||||||
|
|
||||||
|
@mock.patch('grpc.insecure_channel')
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):
|
||||||
|
conn = get_airflow_connection()
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
mock_insecure_channel.return_value = mocked_channel
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
expected_url = "test:8080"
|
||||||
|
|
||||||
|
mock_insecure_channel.assert_called_once_with(expected_url)
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('grpc.insecure_channel')
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
|
||||||
|
conn = get_airflow_connection_with_port()
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
mock_insecure_channel.return_value = mocked_channel
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
expected_url = "test.com:1234"
|
||||||
|
|
||||||
|
mock_insecure_channel.assert_called_once_with(expected_url)
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('airflow.contrib.hooks.grpc_hook.open')
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('grpc.ssl_channel_credentials')
|
||||||
|
@mock.patch('grpc.secure_channel')
|
||||||
|
def test_connection_with_ssl(self,
|
||||||
|
mock_secure_channel,
|
||||||
|
mock_channel_credentials,
|
||||||
|
mock_get_connection,
|
||||||
|
mock_open):
|
||||||
|
conn = get_airflow_connection(
|
||||||
|
auth_type="SSL",
|
||||||
|
credential_pem_file="pem"
|
||||||
|
)
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
mock_open.return_value = StringIO('credential')
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
mock_secure_channel.return_value = mocked_channel
|
||||||
|
mock_credential_object = "test_credential_object"
|
||||||
|
mock_channel_credentials.return_value = mock_credential_object
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
expected_url = "test:8080"
|
||||||
|
|
||||||
|
mock_open.assert_called_once_with("pem")
|
||||||
|
mock_channel_credentials.assert_called_once_with('credential')
|
||||||
|
mock_secure_channel.assert_called_once_with(
|
||||||
|
expected_url,
|
||||||
|
mock_credential_object
|
||||||
|
)
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('airflow.contrib.hooks.grpc_hook.open')
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('grpc.ssl_channel_credentials')
|
||||||
|
@mock.patch('grpc.secure_channel')
|
||||||
|
def test_connection_with_tls(self,
|
||||||
|
mock_secure_channel,
|
||||||
|
mock_channel_credentials,
|
||||||
|
mock_get_connection,
|
||||||
|
mock_open):
|
||||||
|
conn = get_airflow_connection(
|
||||||
|
auth_type="TLS",
|
||||||
|
credential_pem_file="pem"
|
||||||
|
)
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
mock_open.return_value = StringIO('credential')
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
mock_secure_channel.return_value = mocked_channel
|
||||||
|
mock_credential_object = "test_credential_object"
|
||||||
|
mock_channel_credentials.return_value = mock_credential_object
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
expected_url = "test:8080"
|
||||||
|
|
||||||
|
mock_open.assert_called_once_with("pem")
|
||||||
|
mock_channel_credentials.assert_called_once_with('credential')
|
||||||
|
mock_secure_channel.assert_called_once_with(
|
||||||
|
expected_url,
|
||||||
|
mock_credential_object
|
||||||
|
)
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('google.auth.jwt.OnDemandCredentials.from_signing_credentials')
|
||||||
|
@mock.patch('google.auth.default')
|
||||||
|
@mock.patch('google.auth.transport.grpc.secure_authorized_channel')
|
||||||
|
def test_connection_with_jwt(self,
|
||||||
|
mock_secure_channel,
|
||||||
|
mock_google_default_auth,
|
||||||
|
mock_google_cred,
|
||||||
|
mock_get_connection):
|
||||||
|
conn = get_airflow_connection(
|
||||||
|
auth_type="JWT_GOOGLE"
|
||||||
|
)
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
mock_secure_channel.return_value = mocked_channel
|
||||||
|
mock_credential_object = "test_credential_object"
|
||||||
|
mock_google_default_auth.return_value = (mock_credential_object, "")
|
||||||
|
mock_google_cred.return_value = mock_credential_object
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
expected_url = "test:8080"
|
||||||
|
|
||||||
|
mock_google_cred.assert_called_once_with(mock_credential_object)
|
||||||
|
mock_secure_channel.assert_called_once_with(
|
||||||
|
mock_credential_object,
|
||||||
|
None,
|
||||||
|
expected_url
|
||||||
|
)
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('google.auth.transport.requests.Request')
|
||||||
|
@mock.patch('google.auth.default')
|
||||||
|
@mock.patch('google.auth.transport.grpc.secure_authorized_channel')
|
||||||
|
def test_connection_with_google_oauth(self,
|
||||||
|
mock_secure_channel,
|
||||||
|
mock_google_default_auth,
|
||||||
|
mock_google_auth_request,
|
||||||
|
mock_get_connection):
|
||||||
|
conn = get_airflow_connection(
|
||||||
|
auth_type="OATH_GOOGLE",
|
||||||
|
scopes="grpc,gcs"
|
||||||
|
)
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
mock_secure_channel.return_value = mocked_channel
|
||||||
|
mock_credential_object = "test_credential_object"
|
||||||
|
mock_google_default_auth.return_value = (mock_credential_object, "")
|
||||||
|
mock_google_auth_request.return_value = "request"
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
expected_url = "test:8080"
|
||||||
|
|
||||||
|
mock_google_default_auth.assert_called_once_with(scopes=[u"grpc", u"gcs"])
|
||||||
|
mock_secure_channel.assert_called_once_with(
|
||||||
|
mock_credential_object,
|
||||||
|
"request",
|
||||||
|
expected_url
|
||||||
|
)
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
def test_custom_connection(self, mock_get_connection):
|
||||||
|
conn = get_airflow_connection("CUSTOM")
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
def test_custom_connection_with_no_connection_func(self, mock_get_connection):
|
||||||
|
conn = get_airflow_connection("CUSTOM")
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
|
||||||
|
with self.assertRaises(AirflowConfigException):
|
||||||
|
hook.get_conn()
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
def test_connection_type_not_supported(self, mock_get_connection):
|
||||||
|
conn = get_airflow_connection("NOT_SUPPORT")
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
|
||||||
|
with self.assertRaises(AirflowConfigException):
|
||||||
|
hook.get_conn()
|
||||||
|
|
||||||
|
@mock.patch('grpc.intercept_channel')
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('grpc.insecure_channel')
|
||||||
|
def test_connection_with_interceptors(self,
|
||||||
|
mock_insecure_channel,
|
||||||
|
mock_get_connection,
|
||||||
|
mock_intercept_channel):
|
||||||
|
conn = get_airflow_connection()
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
mocked_channel = self.channel_mock.return_value
|
||||||
|
hook = GrpcHook("grpc_default", interceptors=["test1"])
|
||||||
|
mock_insecure_channel.return_value = mocked_channel
|
||||||
|
mock_intercept_channel.return_value = mocked_channel
|
||||||
|
|
||||||
|
channel = hook.get_conn()
|
||||||
|
|
||||||
|
self.assertEquals(channel, mocked_channel)
|
||||||
|
mock_intercept_channel.assert_called_once_with(mocked_channel, "test1")
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('airflow.contrib.hooks.grpc_hook.GrpcHook.get_conn')
|
||||||
|
def test_simple_run(self, mock_get_conn, mock_get_connection):
|
||||||
|
conn = get_airflow_connection()
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
mocked_channel = mock.Mock()
|
||||||
|
mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
|
||||||
|
mocked_channel.__exit__ = mock.Mock(return_value=None)
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mock_get_conn.return_value = mocked_channel
|
||||||
|
|
||||||
|
response = hook.run(StubClass, "single_call", data={'data': 'hello'})
|
||||||
|
|
||||||
|
self.assertEquals(next(response), "hello")
|
||||||
|
|
||||||
|
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||||
|
@mock.patch('airflow.contrib.hooks.grpc_hook.GrpcHook.get_conn')
|
||||||
|
def test_stream_run(self, mock_get_conn, mock_get_connection):
|
||||||
|
conn = get_airflow_connection()
|
||||||
|
mock_get_connection.return_value = conn
|
||||||
|
mocked_channel = mock.Mock()
|
||||||
|
mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
|
||||||
|
mocked_channel.__exit__ = mock.Mock(return_value=None)
|
||||||
|
hook = GrpcHook("grpc_default")
|
||||||
|
mock_get_conn.return_value = mocked_channel
|
||||||
|
|
||||||
|
response = hook.run(StubClass, "stream_call", data={'data': ['hello!', "hi"]})
|
||||||
|
|
||||||
|
self.assertEquals(next(response), ["streaming", "call"])
|
Загрузка…
Ссылка в новой задаче