diff --git a/tests/webapp/api/test_throttling.py b/tests/webapp/api/test_throttling.py new file mode 100644 index 000000000..4d0ee3b05 --- /dev/null +++ b/tests/webapp/api/test_throttling.py @@ -0,0 +1,74 @@ +from rest_framework.response import Response +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView +from hawkrest import HawkAuthentication +from treeherder.webapp.api import throttling + + +class OauthKey1SecRateThrottle(throttling.OauthKeyThrottle): + THROTTLE_RATES = {'foo': '1/sec'} + + +class HawkClient1SecRateThrottle(throttling.HawkClientThrottle): + THROTTLE_RATES = {'foo': '1/sec'} + + +class MockView(APIView): + throttle_classes = (OauthKey1SecRateThrottle, HawkClient1SecRateThrottle,) + throttle_scope = 'foo' + + def get(self, request): + return Response('foo') + + +class MockReceiver(object): + parsed_header = {'id': 'my-client-id'} + + +def mock_authenticate(authentication_class, request): + request.META['hawk.receiver'] = MockReceiver() + + +factory = APIRequestFactory() + + +def test_no_throttle(): + request = factory.get('/') + + response = MockView.as_view()(request) + # first request ok + response.status_code == 200 + + for i in range(1): + response = MockView.as_view()(request) + # subsequent requests still ok + response.status_code == 200 + + +def test_oauth_key_throttle(): + request = factory.get('/', {'oauth_consumer_key': 'my-consumer-key'}) + + response = MockView.as_view()(request) + # first request ok + response.status_code == 200 + + for i in range(1): + response = MockView.as_view()(request) + # subsequent requests should get throttled + assert response.status_code == 429 + + +def test_hawk_client_throttle(monkeypatch): + + monkeypatch.setattr(HawkAuthentication, 'authenticate', mock_authenticate) + + request = factory.get('/') + response = MockView.as_view()(request) + + # first request, everything ok + response.status_code == 200 + + for i in range(1): + response = MockView.as_view()(request) + # subsequent requests should get throttled + assert response.status_code == 429 diff --git a/treeherder/settings/base.py b/treeherder/settings/base.py index 0a1e6c373..12558bb92 100644 --- a/treeherder/settings/base.py +++ b/treeherder/settings/base.py @@ -264,6 +264,7 @@ REST_FRAMEWORK = { 'EXCEPTION_HANDLER': 'treeherder.webapp.api.exceptions.exception_handler', 'DEFAULT_THROTTLE_CLASSES': ( 'treeherder.webapp.api.throttling.OauthKeyThrottle', + 'treeherder.webapp.api.throttling.HawkClientThrottle' ), 'DEFAULT_THROTTLE_RATES': { 'jobs': '220/minute', diff --git a/treeherder/webapp/api/throttling.py b/treeherder/webapp/api/throttling.py index 7a48d08a9..b76712f39 100644 --- a/treeherder/webapp/api/throttling.py +++ b/treeherder/webapp/api/throttling.py @@ -5,9 +5,11 @@ class OauthKeyThrottle(throttling.ScopedRateThrottle): def get_cache_key(self, request, view): """ - If `view.throttle_scope` is not set, don't apply this throttle. - Otherwise generate the unique cache key by concatenating the oauth key - with the '.throttle_scope` property of the view. + Returns a cache_key based on oauth_consumer_key. + + If `view.throttle_scope` is not set or oauth_consumer_key is not set, + don't apply this throttle. Otherwise generate the unique cache key by + concatenating the oauth key with the '.throttle_scope` property of the view. """ ident = request.GET.get('oauth_consumer_key', None) if not ident: @@ -16,3 +18,23 @@ class OauthKeyThrottle(throttling.ScopedRateThrottle): 'scope': self.scope, 'ident': ident } + + +class HawkClientThrottle(throttling.ScopedRateThrottle): + + def get_cache_key(self, request, view): + """ + Returns a cache_key based on the hawk Client ID. + + If `view.throttle_scope` is not set or request.META['hawk.receiver'] is not set, + don't apply this throttle. Otherwise generate the unique cache key by + concatenating the oauth key with the '.throttle_scope` property of the view. + """ + receiver = request.META.get('hawk.receiver') + if receiver is None: + return None + client_id = receiver.parsed_header['id'] + return self.cache_format % { + 'scope': self.scope, + 'ident': client_id + }