diff --git a/airflow/hooks/http_hook.py b/airflow/hooks/http_hook.py index 3a61fab7cf..4f33627f87 100644 --- a/airflow/hooks/http_hook.py +++ b/airflow/hooks/http_hook.py @@ -81,8 +81,8 @@ class HttpHook(BaseHook): return session - def run(self, endpoint, data=None, headers=None, extra_options=None): - """ + def run(self, endpoint, data=None, headers=None, extra_options=None, **request_kwargs): + r""" Performs the request :param endpoint: the endpoint to be called i.e. resource/v1/query? @@ -95,6 +95,8 @@ class HttpHook(BaseHook): i.e. {'check_response': False} to avoid checking raising exceptions on non 2XX or 3XX status codes :type extra_options: dict + :param \**request_kwargs: Additional kwargs to pass when creating a request. + For example, ``run(json=obj)`` is passed as ``requests.Request(json=obj)`` """ extra_options = extra_options or {} @@ -112,18 +114,21 @@ class HttpHook(BaseHook): req = requests.Request(self.method, url, params=data, - headers=headers) + headers=headers, + **request_kwargs) elif self.method == 'HEAD': # HEAD doesn't use params req = requests.Request(self.method, url, - headers=headers) + headers=headers, + **request_kwargs) else: # Others use data req = requests.Request(self.method, url, data=data, - headers=headers) + headers=headers, + **request_kwargs) prepped_request = session.prepare_request(req) self.log.info("Sending '%s' to url: %s", self.method, url) diff --git a/tests/hooks/test_http_hook.py b/tests/hooks/test_http_hook.py index 41e5f9c31b..638a6b3742 100644 --- a/tests/hooks/test_http_hook.py +++ b/tests/hooks/test_http_hook.py @@ -23,6 +23,7 @@ import mock import requests import requests_mock import tenacity +from parameterized import parameterized from airflow.exceptions import AirflowException from airflow.hooks.http_hook import HttpHook @@ -334,5 +335,29 @@ class TestHttpHook(unittest.TestCase): hook.get_conn({}) self.assertEqual(hook.base_url, 'http://') + @parameterized.expand([ + 'GET', + 'POST', + ]) + @requests_mock.mock() + def test_json_request(self, method, mock_requests): + obj1 = {'a': 1, 'b': 'abc', 'c': [1, 2, {"d": 10}]} + + def match_obj1(request): + return request.json() == obj1 + + mock_requests.request( + method=method, + url='//test:8080/v1/test', + additional_matcher=match_obj1 + ) + + with mock.patch( + 'airflow.hooks.base_hook.BaseHook.get_connection', + side_effect=get_airflow_connection + ): + # will raise NoMockAddress exception if obj1 != request.json() + HttpHook(method=method).run('v1/test', json=obj1) + send_email_test = mock.Mock()