339 строки
12 KiB
Python
339 строки
12 KiB
Python
# -*- coding: utf-8 -*-
|
|
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
import json
|
|
import unittest
|
|
|
|
import requests
|
|
import requests_mock
|
|
import tenacity
|
|
|
|
from airflow.exceptions import AirflowException
|
|
from airflow.hooks.http_hook import HttpHook
|
|
from airflow.models import Connection
|
|
from tests.compat import mock
|
|
|
|
|
|
def get_airflow_connection(unused_conn_id=None):
|
|
return Connection(
|
|
conn_id='http_default',
|
|
conn_type='http',
|
|
host='test:8080/',
|
|
extra='{"bareer": "test"}'
|
|
)
|
|
|
|
|
|
def get_airflow_connection_with_port(unused_conn_id=None):
|
|
return Connection(
|
|
conn_id='http_default',
|
|
conn_type='http',
|
|
host='test.com',
|
|
port=1234
|
|
)
|
|
|
|
|
|
class TestHttpHook(unittest.TestCase):
|
|
"""Test get, post and raise_for_status"""
|
|
|
|
def setUp(self):
|
|
session = requests.Session()
|
|
adapter = requests_mock.Adapter()
|
|
session.mount('mock', adapter)
|
|
self.get_hook = HttpHook(method='GET')
|
|
self.get_lowercase_hook = HttpHook(method='get')
|
|
self.post_hook = HttpHook(method='POST')
|
|
|
|
@requests_mock.mock()
|
|
def test_raise_for_status_with_200(self, m):
|
|
|
|
m.get(
|
|
'http://test:8080/v1/test',
|
|
status_code=200,
|
|
text='{"status":{"status": 200}}',
|
|
reason='OK'
|
|
)
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
resp = self.get_hook.run('v1/test')
|
|
self.assertEqual(resp.text, '{"status":{"status": 200}}')
|
|
|
|
@requests_mock.mock()
|
|
@mock.patch('requests.Session')
|
|
@mock.patch('requests.Request')
|
|
def test_get_request_with_port(self, mock_requests, request_mock, mock_session):
|
|
from requests.exceptions import MissingSchema
|
|
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection_with_port
|
|
):
|
|
expected_url = 'http://test.com:1234/some/endpoint'
|
|
for endpoint in ['some/endpoint', '/some/endpoint']:
|
|
|
|
try:
|
|
self.get_hook.run(endpoint)
|
|
except MissingSchema:
|
|
pass
|
|
|
|
request_mock.assert_called_once_with(
|
|
mock.ANY,
|
|
expected_url,
|
|
headers=mock.ANY,
|
|
params=mock.ANY
|
|
)
|
|
|
|
request_mock.reset_mock()
|
|
|
|
@requests_mock.mock()
|
|
def test_get_request_do_not_raise_for_status_if_check_response_is_false(self, m):
|
|
|
|
m.get(
|
|
'http://test:8080/v1/test',
|
|
status_code=404,
|
|
text='{"status":{"status": 404}}',
|
|
reason='Bad request'
|
|
)
|
|
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
resp = self.get_hook.run('v1/test', extra_options={'check_response': False})
|
|
self.assertEqual(resp.text, '{"status":{"status": 404}}')
|
|
|
|
@requests_mock.mock()
|
|
def test_hook_contains_header_from_extra_field(self, mock_requests):
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
expected_conn = get_airflow_connection()
|
|
conn = self.get_hook.get_conn()
|
|
self.assertDictContainsSubset(json.loads(expected_conn.extra), conn.headers)
|
|
self.assertEqual(conn.headers.get('bareer'), 'test')
|
|
|
|
@requests_mock.mock()
|
|
@mock.patch('requests.Request')
|
|
def test_hook_with_method_in_lowercase(self, mock_requests, request_mock):
|
|
from requests.exceptions import MissingSchema, InvalidURL
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection_with_port
|
|
):
|
|
data = "test params"
|
|
try:
|
|
self.get_lowercase_hook.run('v1/test', data=data)
|
|
except (MissingSchema, InvalidURL):
|
|
pass
|
|
request_mock.assert_called_once_with(
|
|
mock.ANY,
|
|
mock.ANY,
|
|
headers=mock.ANY,
|
|
params=data
|
|
)
|
|
|
|
@requests_mock.mock()
|
|
def test_hook_uses_provided_header(self, mock_requests):
|
|
conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
|
|
self.assertEqual(conn.headers.get('bareer'), "newT0k3n")
|
|
|
|
@requests_mock.mock()
|
|
def test_hook_has_no_header_from_extra(self, mock_requests):
|
|
conn = self.get_hook.get_conn()
|
|
self.assertIsNone(conn.headers.get('bareer'))
|
|
|
|
@requests_mock.mock()
|
|
def test_hooks_header_from_extra_is_overridden(self, mock_requests):
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
conn = self.get_hook.get_conn(headers={"bareer": "newT0k3n"})
|
|
self.assertEqual(conn.headers.get('bareer'), 'newT0k3n')
|
|
|
|
@requests_mock.mock()
|
|
def test_post_request(self, mock_requests):
|
|
mock_requests.post(
|
|
'http://test:8080/v1/test',
|
|
status_code=200,
|
|
text='{"status":{"status": 200}}',
|
|
reason='OK'
|
|
)
|
|
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
resp = self.post_hook.run('v1/test')
|
|
self.assertEqual(resp.status_code, 200)
|
|
|
|
@requests_mock.mock()
|
|
def test_post_request_with_error_code(self, mock_requests):
|
|
mock_requests.post(
|
|
'http://test:8080/v1/test',
|
|
status_code=418,
|
|
text='{"status":{"status": 418}}',
|
|
reason='I\'m a teapot'
|
|
)
|
|
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
with self.assertRaises(AirflowException):
|
|
self.post_hook.run('v1/test')
|
|
|
|
@requests_mock.mock()
|
|
def test_post_request_do_not_raise_for_status_if_check_response_is_false(self, mock_requests):
|
|
mock_requests.post(
|
|
'http://test:8080/v1/test',
|
|
status_code=418,
|
|
text='{"status":{"status": 418}}',
|
|
reason='I\'m a teapot'
|
|
)
|
|
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
resp = self.post_hook.run('v1/test', extra_options={'check_response': False})
|
|
self.assertEqual(resp.status_code, 418)
|
|
|
|
@mock.patch('airflow.hooks.http_hook.requests.Session')
|
|
def test_retry_on_conn_error(self, mocked_session):
|
|
|
|
retry_args = dict(
|
|
wait=tenacity.wait_none(),
|
|
stop=tenacity.stop_after_attempt(7),
|
|
retry=tenacity.retry_if_exception_type(
|
|
requests.exceptions.ConnectionError
|
|
)
|
|
)
|
|
|
|
def send_and_raise(unused_request, **kwargs):
|
|
raise requests.exceptions.ConnectionError
|
|
|
|
mocked_session().send.side_effect = send_and_raise
|
|
# The job failed for some reason
|
|
with self.assertRaises(tenacity.RetryError):
|
|
self.get_hook.run_with_advanced_retry(
|
|
endpoint='v1/test',
|
|
_retry_args=retry_args
|
|
)
|
|
self.assertEqual(
|
|
self.get_hook._retry_obj.stop.max_attempt_number + 1,
|
|
mocked_session.call_count
|
|
)
|
|
|
|
@requests_mock.mock()
|
|
def test_run_with_advanced_retry(self, m):
|
|
|
|
m.get(
|
|
'http://test:8080/v1/test',
|
|
status_code=200,
|
|
reason='OK'
|
|
)
|
|
|
|
retry_args = dict(
|
|
wait=tenacity.wait_none(),
|
|
stop=tenacity.stop_after_attempt(3),
|
|
retry=tenacity.retry_if_exception_type(Exception),
|
|
reraise=True
|
|
)
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
response = self.get_hook.run_with_advanced_retry(
|
|
endpoint='v1/test',
|
|
_retry_args=retry_args
|
|
)
|
|
self.assertIsInstance(response, requests.Response)
|
|
|
|
def test_header_from_extra_and_run_method_are_merged(self):
|
|
|
|
def run_and_return(unused_session, prepped_request, unused_extra_options, **kwargs):
|
|
return prepped_request
|
|
|
|
# The job failed for some reason
|
|
with mock.patch(
|
|
'airflow.hooks.http_hook.HttpHook.run_and_check',
|
|
side_effect=run_and_return
|
|
):
|
|
with mock.patch(
|
|
'airflow.hooks.base_hook.BaseHook.get_connection',
|
|
side_effect=get_airflow_connection
|
|
):
|
|
pr = self.get_hook.run('v1/test', headers={'some_other_header': 'test'})
|
|
actual = dict(pr.headers)
|
|
self.assertEqual(actual.get('bareer'), 'test')
|
|
self.assertEqual(actual.get('some_other_header'), 'test')
|
|
|
|
@mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
|
|
def test_http_connection(self, mock_get_connection):
|
|
c = Connection(conn_id='http_default', conn_type='http',
|
|
host='localhost', schema='http')
|
|
mock_get_connection.return_value = c
|
|
hook = HttpHook()
|
|
hook.get_conn({})
|
|
self.assertEqual(hook.base_url, 'http://localhost')
|
|
|
|
@mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
|
|
def test_https_connection(self, mock_get_connection):
|
|
c = Connection(conn_id='http_default', conn_type='http',
|
|
host='localhost', schema='https')
|
|
mock_get_connection.return_value = c
|
|
hook = HttpHook()
|
|
hook.get_conn({})
|
|
self.assertEqual(hook.base_url, 'https://localhost')
|
|
|
|
@mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
|
|
def test_host_encoded_http_connection(self, mock_get_connection):
|
|
c = Connection(conn_id='http_default', conn_type='http',
|
|
host='http://localhost')
|
|
mock_get_connection.return_value = c
|
|
hook = HttpHook()
|
|
hook.get_conn({})
|
|
self.assertEqual(hook.base_url, 'http://localhost')
|
|
|
|
@mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
|
|
def test_host_encoded_https_connection(self, mock_get_connection):
|
|
c = Connection(conn_id='http_default', conn_type='http',
|
|
host='https://localhost')
|
|
mock_get_connection.return_value = c
|
|
hook = HttpHook()
|
|
hook.get_conn({})
|
|
self.assertEqual(hook.base_url, 'https://localhost')
|
|
|
|
def test_method_converted_to_uppercase_when_created_in_lowercase(self):
|
|
self.assertEqual(self.get_lowercase_hook.method, 'GET')
|
|
|
|
@mock.patch('airflow.hooks.http_hook.HttpHook.get_connection')
|
|
def test_connection_without_host(self, mock_get_connection):
|
|
c = Connection(conn_id='http_default', conn_type='http')
|
|
mock_get_connection.return_value = c
|
|
|
|
hook = HttpHook()
|
|
hook.get_conn({})
|
|
self.assertEqual(hook.base_url, 'http://')
|
|
|
|
|
|
send_email_test = mock.Mock()
|