[AIRFLOW-3272] Add base grpc hook (#4101)
* [AIRFLOW-3272] add base grpc hook * [AIRFLOW-3272] fix based on comments and add more docs * [AIRFLOW-3272] add extra fields to www_rabc view in connection model * [AIRFLOW-3272] change url for grpc, fix some bugs * [AIRFLOW-3272] Add mcck for grpc * [AIRFLOW-3272] add unit tests for grpc hook * [AIRFLOW-3272] add gRPC connection howto doc
This commit is contained in:
Родитель
ddec6bbeb3
Коммит
8d5d46022b
|
@ -0,0 +1,123 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import grpc
|
||||
from google import auth as google_auth
|
||||
from google.auth import jwt as google_auth_jwt
|
||||
from google.auth.transport import grpc as google_auth_transport_grpc
|
||||
from google.auth.transport import requests as google_auth_transport_requests
|
||||
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
from airflow.exceptions import AirflowConfigException
|
||||
|
||||
|
||||
class GrpcHook(BaseHook):
|
||||
"""
|
||||
General interaction with gRPC servers.
|
||||
"""
|
||||
|
||||
def __init__(self, grpc_conn_id, interceptors=None, custom_connection_func=None):
|
||||
"""
|
||||
:param grpc_conn_id: The connection ID to use when fetching connection info.
|
||||
:type grpc_conn_id: str
|
||||
:param interceptors: a list of gRPC interceptor objects which would be applied
|
||||
to the connected gRPC channel. None by default.
|
||||
:type interceptors: a list of gRPC interceptors based on or extends the four
|
||||
official gRPC interceptors, eg, UnaryUnaryClientInterceptor,
|
||||
UnaryStreamClientInterceptor, StreamUnaryClientInterceptor,
|
||||
StreamStreamClientInterceptor.
|
||||
::param custom_connection_func: The customized connection function to return gRPC channel.
|
||||
:type custom_connection_func: python callable objects that accept the connection as
|
||||
its only arg. Could be partial or lambda.
|
||||
"""
|
||||
self.grpc_conn_id = grpc_conn_id
|
||||
self.conn = self.get_connection(self.grpc_conn_id)
|
||||
self.extras = self.conn.extra_dejson
|
||||
self.interceptors = interceptors if interceptors else []
|
||||
self.custom_connection_func = custom_connection_func
|
||||
|
||||
def get_conn(self):
|
||||
base_url = self.conn.host
|
||||
|
||||
if self.conn.port:
|
||||
base_url = base_url + ":" + str(self.conn.port)
|
||||
|
||||
auth_type = self._get_field("auth_type")
|
||||
|
||||
if auth_type == "NO_AUTH":
|
||||
channel = grpc.insecure_channel(base_url)
|
||||
elif auth_type == "SSL" or auth_type == "TLS":
|
||||
credential_file_name = self._get_field("credential_pem_file")
|
||||
creds = grpc.ssl_channel_credentials(open(credential_file_name).read())
|
||||
channel = grpc.secure_channel(base_url, creds)
|
||||
elif auth_type == "JWT_GOOGLE":
|
||||
credentials, _ = google_auth.default()
|
||||
jwt_creds = google_auth_jwt.OnDemandCredentials.from_signing_credentials(
|
||||
credentials)
|
||||
channel = google_auth_transport_grpc.secure_authorized_channel(
|
||||
jwt_creds, None, base_url)
|
||||
elif auth_type == "OATH_GOOGLE":
|
||||
scopes = self._get_field("scopes").split(",")
|
||||
credentials, _ = google_auth.default(scopes=scopes)
|
||||
request = google_auth_transport_requests.Request()
|
||||
channel = google_auth_transport_grpc.secure_authorized_channel(
|
||||
credentials, request, base_url)
|
||||
elif auth_type == "CUSTOM":
|
||||
if not self.custom_connection_func:
|
||||
raise AirflowConfigException(
|
||||
"Customized connection function not set, not able to establish a channel")
|
||||
channel = self.custom_connection_func(self.conn)
|
||||
else:
|
||||
raise AirflowConfigException(
|
||||
"auth_type not supported or not provided, channel cannot be established,\
|
||||
given value: %s" % str(auth_type))
|
||||
|
||||
if self.interceptors:
|
||||
for interceptor in self.interceptors:
|
||||
channel = grpc.intercept_channel(channel,
|
||||
interceptor)
|
||||
|
||||
return channel
|
||||
|
||||
def run(self, stub_class, call_func, streaming=False, data={}):
|
||||
with self.get_conn() as channel:
|
||||
stub = stub_class(channel)
|
||||
try:
|
||||
rpc_func = getattr(stub, call_func)
|
||||
response = rpc_func(**data)
|
||||
if not streaming:
|
||||
yield response
|
||||
else:
|
||||
for single_response in response:
|
||||
yield single_response
|
||||
except grpc.RpcError as ex:
|
||||
self.log.exception(
|
||||
"Error occured when calling the grpc service: {0}, method: {1} \
|
||||
status code: {2}, error details: {3}"
|
||||
.format(stub.__class__.__name__, call_func, ex.code(), ex.details()))
|
||||
raise ex
|
||||
|
||||
def _get_field(self, field_name, default=None):
|
||||
"""
|
||||
Fetches a field from extras, and returns it. This is some Airflow
|
||||
magic. The grpc hook type adds custom UI elements
|
||||
to the hook page, which allow admins to specify scopes, credential pem files, etc.
|
||||
They get formatted as shown below.
|
||||
"""
|
||||
full_field_name = 'extra__grpc__{}'.format(field_name)
|
||||
if full_field_name in self.extras:
|
||||
return self.extras[full_field_name]
|
||||
else:
|
||||
return default
|
|
@ -90,6 +90,7 @@ class Connection(Base, LoggingMixin):
|
|||
('qubole', 'Qubole'),
|
||||
('mongo', 'MongoDB'),
|
||||
('gcpcloudsql', 'Google Cloud SQL'),
|
||||
('grpc', 'GRPC Connection'),
|
||||
]
|
||||
|
||||
def __init__(
|
||||
|
@ -247,6 +248,9 @@ class Connection(Base, LoggingMixin):
|
|||
elif self.conn_type == 'gcpcloudsql':
|
||||
from airflow.contrib.hooks.gcp_sql_hook import CloudSqlDatabaseHook
|
||||
return CloudSqlDatabaseHook(gcp_cloudsql_conn_id=self.conn_id)
|
||||
elif self.conn_type == 'grpc':
|
||||
from airflow.contrib.hooks.grpc_hook import GrpcHook
|
||||
return GrpcHook(grpc_conn_id=self.conn_id)
|
||||
|
||||
def __repr__(self):
|
||||
return self.conn_id
|
||||
|
|
|
@ -1928,7 +1928,10 @@ class ConnectionModelView(AirflowModelView):
|
|||
'extra__google_cloud_platform__project',
|
||||
'extra__google_cloud_platform__key_path',
|
||||
'extra__google_cloud_platform__keyfile_dict',
|
||||
'extra__google_cloud_platform__scope']
|
||||
'extra__google_cloud_platform__scope',
|
||||
'extra__grpc__auth_type',
|
||||
'extra__grpc__credential_pem_file',
|
||||
'extra__grpc__scopes']
|
||||
list_columns = ['conn_id', 'conn_type', 'host', 'port', 'is_encrypted',
|
||||
'is_extra_encrypted']
|
||||
add_columns = edit_columns = ['conn_id', 'conn_type', 'host', 'schema',
|
||||
|
@ -1949,7 +1952,7 @@ class ConnectionModelView(AirflowModelView):
|
|||
|
||||
def process_form(self, form, is_created):
|
||||
formdata = form.data
|
||||
if formdata['conn_type'] in ['jdbc', 'google_cloud_platform']:
|
||||
if formdata['conn_type'] in ['jdbc', 'google_cloud_platform', 'grpc']:
|
||||
extra = {
|
||||
key: formdata[key]
|
||||
for key in self.extra_fields if key in formdata}
|
||||
|
|
|
@ -510,6 +510,7 @@ Community contributed hooks
|
|||
.. autoclass:: airflow.contrib.hooks.gcp_api_base_hook.GoogleCloudBaseHook
|
||||
.. autoclass:: airflow.contrib.hooks.gcp_kms_hook.GoogleCloudKMSHook
|
||||
.. autoclass:: airflow.contrib.hooks.gcs_hook.GoogleCloudStorageHook
|
||||
.. autoclass:: airflow.contrib.hooks.grpc_hook.GrpcHook
|
||||
.. autoclass:: airflow.contrib.hooks.imap_hook.ImapHook
|
||||
.. autoclass:: airflow.contrib.hooks.jenkins_hook.JenkinsHook
|
||||
.. autoclass:: airflow.contrib.hooks.jira_hook.JiraHook
|
||||
|
|
|
@ -58,6 +58,7 @@ autodoc_mock_imports = [
|
|||
'google',
|
||||
'google_auth_httplib2',
|
||||
'googleapiclient',
|
||||
'grpc',
|
||||
'hdfs',
|
||||
'httplib2',
|
||||
'jaydebeapi',
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
.. Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
.. http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
.. Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
|
||||
gRPC
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
The gRPC connection type enables integrated connection to a gRPC service
|
||||
|
||||
Authenticating to gRPC
|
||||
'''''''''''''''''''''''
|
||||
|
||||
There are several ways to connect to gRPC service using Airflow.
|
||||
|
||||
1. Using `NO_AUTH` mode, simply setup an insecure channel of connection.
|
||||
2. Using `SSL` or `TLS` mode, supply a credential pem file for the connection id,
|
||||
this will setup SSL or TLS secured connection with gRPC service.
|
||||
3. Using `JWT_GOOGLE` mode. It is using google auth default credentials by default,
|
||||
further use case of getting credentials from service account can be add later on.
|
||||
4. Using `OATH_GOOGLE` mode. Scopes are required in the extra field, can be setup in the UI.
|
||||
It is using google auth default credentials by default,
|
||||
further use case of getting credentials from service account can be add later on.
|
||||
5. Using `CUSTOM` mode. For this type of connection, you can pass in a connection
|
||||
function takes in the connection object and return a gRPC channel and supply whatever
|
||||
authentication type you want.
|
||||
|
||||
Default Connection IDs
|
||||
''''''''''''''''''''''
|
||||
|
||||
The following connection IDs are used by default.
|
||||
|
||||
``grpc_default``
|
||||
Used by the :class:`~airflow.contrib.hooks.grpc_hook.GrpcHook`
|
||||
hook.
|
||||
|
||||
Configuring the Connection
|
||||
''''''''''''''''''''''''''
|
||||
|
||||
Host
|
||||
The host url of the gRPC server
|
||||
|
||||
Port (Optional)
|
||||
The port to connect to on gRPC server
|
||||
|
||||
Auth Type
|
||||
Authentication type of the gRPC connection.
|
||||
`NO_AUTH` by default, possible values are
|
||||
`NO_AUTH`, `SSL`, `TLS`, `JWT_GOOGLE`,
|
||||
`OATH_GOOGLE`, `CUSTOM`
|
||||
|
||||
Credential Pem File (Optional)
|
||||
Pem file that contains credentials for
|
||||
`SSL` and `TLS` type auth
|
||||
Not required for other types.
|
||||
|
||||
Scopes (comma separated) (Optional)
|
||||
A list of comma-separated `Google Cloud scopes
|
||||
<https://developers.google.com/identity/protocols/googlescopes>`_ to
|
||||
authenticate with.
|
||||
Only for `OATH_GOOGLE` type connection
|
4
setup.py
4
setup.py
|
@ -186,6 +186,7 @@ gcp = [
|
|||
'pandas-gbq'
|
||||
]
|
||||
github_enterprise = ['Flask-OAuthlib>=0.9.1']
|
||||
grpc = ['grpcio>=1.15.0']
|
||||
google_auth = ['Flask-OAuthlib>=0.9.1']
|
||||
hdfs = ['snakebite>=2.7.8']
|
||||
hive = [
|
||||
|
@ -260,7 +261,7 @@ if not PY3:
|
|||
devel_minreq = devel + kubernetes + mysql + doc + password + cgroups
|
||||
devel_hadoop = devel_minreq + hive + hdfs + webhdfs + kerberos
|
||||
devel_all = (sendgrid + devel + all_dbs + doc + samba + slack + crypto + oracle +
|
||||
docker + ssh + kubernetes + celery + redis + gcp +
|
||||
docker + ssh + kubernetes + celery + redis + gcp + grpc +
|
||||
datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins +
|
||||
druid + pinot + segment + snowflake + elasticsearch +
|
||||
atlas + azure + aws)
|
||||
|
@ -355,6 +356,7 @@ def do_setup():
|
|||
'gcp_api': gcp, # TODO: remove this in Airflow 2.1
|
||||
'github_enterprise': github_enterprise,
|
||||
'google_auth': google_auth,
|
||||
'grpc': grpc,
|
||||
'hdfs': hdfs,
|
||||
'hive': hive,
|
||||
'jdbc': jdbc,
|
||||
|
|
|
@ -0,0 +1,314 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
try:
|
||||
from StringIO import StringIO
|
||||
except ImportError:
|
||||
from io import StringIO
|
||||
|
||||
from airflow import configuration
|
||||
from airflow.exceptions import AirflowConfigException
|
||||
from airflow.contrib.hooks.grpc_hook import GrpcHook
|
||||
from airflow.models.connection import Connection
|
||||
|
||||
try:
|
||||
from unittest import mock
|
||||
except ImportError:
|
||||
try:
|
||||
import mock
|
||||
except ImportError:
|
||||
mock = None
|
||||
|
||||
|
||||
def get_airflow_connection(auth_type="NO_AUTH", credential_pem_file=None, scopes=None):
|
||||
extra = \
|
||||
'{{"extra__grpc__auth_type": "{auth_type}",' \
|
||||
'"extra__grpc__credential_pem_file": "{credential_pem_file}",' \
|
||||
'"extra__grpc__scopes": "{scopes}"}}' \
|
||||
.format(auth_type=auth_type,
|
||||
credential_pem_file=credential_pem_file,
|
||||
scopes=scopes)
|
||||
|
||||
return Connection(
|
||||
conn_id='grpc_default',
|
||||
conn_type='grpc',
|
||||
host='test:8080',
|
||||
extra=extra
|
||||
)
|
||||
|
||||
|
||||
def get_airflow_connection_with_port():
|
||||
return Connection(
|
||||
conn_id='grpc_default',
|
||||
conn_type='grpc',
|
||||
host='test.com',
|
||||
port=1234,
|
||||
extra='{"extra__grpc__auth_type": "NO_AUTH"}'
|
||||
)
|
||||
|
||||
|
||||
class StubClass(object):
|
||||
def __init__(self, channel):
|
||||
pass
|
||||
|
||||
def single_call(self, data):
|
||||
return data
|
||||
|
||||
def stream_call(self, data):
|
||||
return ["streaming", "call"]
|
||||
|
||||
|
||||
class TestGrpcHook(unittest.TestCase):
|
||||
def setUp(self):
|
||||
configuration.load_test_config()
|
||||
self.channel_mock = mock.patch('grpc.Channel').start()
|
||||
|
||||
def custom_conn_func(self, connection):
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
return mocked_channel
|
||||
|
||||
@mock.patch('grpc.insecure_channel')
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
def test_no_auth_connection(self, mock_get_connection, mock_insecure_channel):
|
||||
conn = get_airflow_connection()
|
||||
mock_get_connection.return_value = conn
|
||||
hook = GrpcHook("grpc_default")
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
mock_insecure_channel.return_value = mocked_channel
|
||||
|
||||
channel = hook.get_conn()
|
||||
expected_url = "test:8080"
|
||||
|
||||
mock_insecure_channel.assert_called_once_with(expected_url)
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('grpc.insecure_channel')
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
def test_connection_with_port(self, mock_get_connection, mock_insecure_channel):
|
||||
conn = get_airflow_connection_with_port()
|
||||
mock_get_connection.return_value = conn
|
||||
hook = GrpcHook("grpc_default")
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
mock_insecure_channel.return_value = mocked_channel
|
||||
|
||||
channel = hook.get_conn()
|
||||
expected_url = "test.com:1234"
|
||||
|
||||
mock_insecure_channel.assert_called_once_with(expected_url)
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('airflow.contrib.hooks.grpc_hook.open')
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('grpc.ssl_channel_credentials')
|
||||
@mock.patch('grpc.secure_channel')
|
||||
def test_connection_with_ssl(self,
|
||||
mock_secure_channel,
|
||||
mock_channel_credentials,
|
||||
mock_get_connection,
|
||||
mock_open):
|
||||
conn = get_airflow_connection(
|
||||
auth_type="SSL",
|
||||
credential_pem_file="pem"
|
||||
)
|
||||
mock_get_connection.return_value = conn
|
||||
mock_open.return_value = StringIO('credential')
|
||||
hook = GrpcHook("grpc_default")
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
mock_secure_channel.return_value = mocked_channel
|
||||
mock_credential_object = "test_credential_object"
|
||||
mock_channel_credentials.return_value = mock_credential_object
|
||||
|
||||
channel = hook.get_conn()
|
||||
expected_url = "test:8080"
|
||||
|
||||
mock_open.assert_called_once_with("pem")
|
||||
mock_channel_credentials.assert_called_once_with('credential')
|
||||
mock_secure_channel.assert_called_once_with(
|
||||
expected_url,
|
||||
mock_credential_object
|
||||
)
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('airflow.contrib.hooks.grpc_hook.open')
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('grpc.ssl_channel_credentials')
|
||||
@mock.patch('grpc.secure_channel')
|
||||
def test_connection_with_tls(self,
|
||||
mock_secure_channel,
|
||||
mock_channel_credentials,
|
||||
mock_get_connection,
|
||||
mock_open):
|
||||
conn = get_airflow_connection(
|
||||
auth_type="TLS",
|
||||
credential_pem_file="pem"
|
||||
)
|
||||
mock_get_connection.return_value = conn
|
||||
mock_open.return_value = StringIO('credential')
|
||||
hook = GrpcHook("grpc_default")
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
mock_secure_channel.return_value = mocked_channel
|
||||
mock_credential_object = "test_credential_object"
|
||||
mock_channel_credentials.return_value = mock_credential_object
|
||||
|
||||
channel = hook.get_conn()
|
||||
expected_url = "test:8080"
|
||||
|
||||
mock_open.assert_called_once_with("pem")
|
||||
mock_channel_credentials.assert_called_once_with('credential')
|
||||
mock_secure_channel.assert_called_once_with(
|
||||
expected_url,
|
||||
mock_credential_object
|
||||
)
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('google.auth.jwt.OnDemandCredentials.from_signing_credentials')
|
||||
@mock.patch('google.auth.default')
|
||||
@mock.patch('google.auth.transport.grpc.secure_authorized_channel')
|
||||
def test_connection_with_jwt(self,
|
||||
mock_secure_channel,
|
||||
mock_google_default_auth,
|
||||
mock_google_cred,
|
||||
mock_get_connection):
|
||||
conn = get_airflow_connection(
|
||||
auth_type="JWT_GOOGLE"
|
||||
)
|
||||
mock_get_connection.return_value = conn
|
||||
hook = GrpcHook("grpc_default")
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
mock_secure_channel.return_value = mocked_channel
|
||||
mock_credential_object = "test_credential_object"
|
||||
mock_google_default_auth.return_value = (mock_credential_object, "")
|
||||
mock_google_cred.return_value = mock_credential_object
|
||||
|
||||
channel = hook.get_conn()
|
||||
expected_url = "test:8080"
|
||||
|
||||
mock_google_cred.assert_called_once_with(mock_credential_object)
|
||||
mock_secure_channel.assert_called_once_with(
|
||||
mock_credential_object,
|
||||
None,
|
||||
expected_url
|
||||
)
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('google.auth.transport.requests.Request')
|
||||
@mock.patch('google.auth.default')
|
||||
@mock.patch('google.auth.transport.grpc.secure_authorized_channel')
|
||||
def test_connection_with_google_oauth(self,
|
||||
mock_secure_channel,
|
||||
mock_google_default_auth,
|
||||
mock_google_auth_request,
|
||||
mock_get_connection):
|
||||
conn = get_airflow_connection(
|
||||
auth_type="OATH_GOOGLE",
|
||||
scopes="grpc,gcs"
|
||||
)
|
||||
mock_get_connection.return_value = conn
|
||||
hook = GrpcHook("grpc_default")
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
mock_secure_channel.return_value = mocked_channel
|
||||
mock_credential_object = "test_credential_object"
|
||||
mock_google_default_auth.return_value = (mock_credential_object, "")
|
||||
mock_google_auth_request.return_value = "request"
|
||||
|
||||
channel = hook.get_conn()
|
||||
expected_url = "test:8080"
|
||||
|
||||
mock_google_default_auth.assert_called_once_with(scopes=[u"grpc", u"gcs"])
|
||||
mock_secure_channel.assert_called_once_with(
|
||||
mock_credential_object,
|
||||
"request",
|
||||
expected_url
|
||||
)
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
def test_custom_connection(self, mock_get_connection):
|
||||
conn = get_airflow_connection("CUSTOM")
|
||||
mock_get_connection.return_value = conn
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
hook = GrpcHook("grpc_default", custom_connection_func=self.custom_conn_func)
|
||||
|
||||
channel = hook.get_conn()
|
||||
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
def test_custom_connection_with_no_connection_func(self, mock_get_connection):
|
||||
conn = get_airflow_connection("CUSTOM")
|
||||
mock_get_connection.return_value = conn
|
||||
hook = GrpcHook("grpc_default")
|
||||
|
||||
with self.assertRaises(AirflowConfigException):
|
||||
hook.get_conn()
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
def test_connection_type_not_supported(self, mock_get_connection):
|
||||
conn = get_airflow_connection("NOT_SUPPORT")
|
||||
mock_get_connection.return_value = conn
|
||||
hook = GrpcHook("grpc_default")
|
||||
|
||||
with self.assertRaises(AirflowConfigException):
|
||||
hook.get_conn()
|
||||
|
||||
@mock.patch('grpc.intercept_channel')
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('grpc.insecure_channel')
|
||||
def test_connection_with_interceptors(self,
|
||||
mock_insecure_channel,
|
||||
mock_get_connection,
|
||||
mock_intercept_channel):
|
||||
conn = get_airflow_connection()
|
||||
mock_get_connection.return_value = conn
|
||||
mocked_channel = self.channel_mock.return_value
|
||||
hook = GrpcHook("grpc_default", interceptors=["test1"])
|
||||
mock_insecure_channel.return_value = mocked_channel
|
||||
mock_intercept_channel.return_value = mocked_channel
|
||||
|
||||
channel = hook.get_conn()
|
||||
|
||||
self.assertEquals(channel, mocked_channel)
|
||||
mock_intercept_channel.assert_called_once_with(mocked_channel, "test1")
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('airflow.contrib.hooks.grpc_hook.GrpcHook.get_conn')
|
||||
def test_simple_run(self, mock_get_conn, mock_get_connection):
|
||||
conn = get_airflow_connection()
|
||||
mock_get_connection.return_value = conn
|
||||
mocked_channel = mock.Mock()
|
||||
mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
|
||||
mocked_channel.__exit__ = mock.Mock(return_value=None)
|
||||
hook = GrpcHook("grpc_default")
|
||||
mock_get_conn.return_value = mocked_channel
|
||||
|
||||
response = hook.run(StubClass, "single_call", data={'data': 'hello'})
|
||||
|
||||
self.assertEquals(next(response), "hello")
|
||||
|
||||
@mock.patch('airflow.hooks.base_hook.BaseHook.get_connection')
|
||||
@mock.patch('airflow.contrib.hooks.grpc_hook.GrpcHook.get_conn')
|
||||
def test_stream_run(self, mock_get_conn, mock_get_connection):
|
||||
conn = get_airflow_connection()
|
||||
mock_get_connection.return_value = conn
|
||||
mocked_channel = mock.Mock()
|
||||
mocked_channel.__enter__ = mock.Mock(return_value=(mock.Mock(), None))
|
||||
mocked_channel.__exit__ = mock.Mock(return_value=None)
|
||||
hook = GrpcHook("grpc_default")
|
||||
mock_get_conn.return_value = mocked_channel
|
||||
|
||||
response = hook.run(StubClass, "stream_call", data={'data': ['hello!', "hi"]})
|
||||
|
||||
self.assertEquals(next(response), ["streaming", "call"])
|
Загрузка…
Ссылка в новой задаче