[AIRFLOW-4965] Handle quote exceptions in GCP operators (#6305)

This commit is contained in:
Kamil Breguła 2019-10-22 08:29:33 +02:00 коммит произвёл Jarek Potiuk
Родитель b08f86290e
Коммит 417754a14a
8 изменённых файлов: 149 добавлений и 3 удалений

Просмотреть файл

@ -23,6 +23,7 @@ This module contains a Google Cloud API base hook.
import functools
import json
import logging
import os
import tempfile
from contextlib import contextmanager
@ -32,17 +33,69 @@ import google.auth
import google.oauth2.service_account
import google_auth_httplib2
import httplib2
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError, RetryError
import tenacity
from google.api_core.exceptions import (
AlreadyExists, Forbidden, GoogleAPICallError, ResourceExhausted, RetryError,
)
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.environment_vars import CREDENTIALS
from googleapiclient.errors import HttpError
from airflow import version
from airflow import LoggingMixin, version
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
logger = LoggingMixin().log
_DEFAULT_SCOPES = ('https://www.googleapis.com/auth/cloud-platform',) # type: Sequence[str]
# Constants used by the mechanism of repeating requests in reaction to exceeding the temporary quota.
INVALID_KEYS = [
'DefaultRequestsPerMinutePerProject',
'DefaultRequestsPerMinutePerUser',
'RequestsPerMinutePerProject',
"Resource has been exhausted (e.g. check quota).",
]
INVALID_REASONS = [
'userRateLimitExceeded',
]
def is_soft_quota_exception(exception: Exception):
"""
API for Google services does not have a standardized way to report quota violation errors.
The function has been adapted by trial and error to the following services:
* Google Translate
* Google Vision
* Google Text-to-Speech
* Google Speech-to-Text
* Google Natural Language
* Google Video Intelligence
"""
if isinstance(exception, Forbidden):
return any(
reason in error["reason"]
for reason in INVALID_REASONS
for error in exception.errors
)
if isinstance(exception, ResourceExhausted):
return any(
key in error.details()
for key in INVALID_KEYS
for error in exception.errors
)
return False
class retry_if_temporary_quota(tenacity.retry_if_exception): # pylint: disable=invalid-name
"""Retries if there was an exception for exceeding the temporary quote limit."""
def __init__(self):
super().__init__(is_soft_quota_exception)
RT = TypeVar('RT') # pylint: disable=invalid-name
@ -221,6 +274,25 @@ class GoogleCloudBaseHook(BaseHook):
return [s.strip() for s in scope_value.split(',')] \
if scope_value else _DEFAULT_SCOPES
@staticmethod
def quota_retry(*args, **kwargs) -> Callable:
"""
A decorator who provides a mechanism to repeat requests in response to exceeding a temporary quote
limit.
"""
def decorator(fun: Callable):
default_kwargs = {
'wait': tenacity.wait_exponential(multiplier=1, max=100),
'retry': retry_if_temporary_quota(),
'before': tenacity.before_log(logger, logging.DEBUG),
'after': tenacity.after_log(logger, logging.DEBUG),
}
default_kwargs.update(**kwargs)
return tenacity.retry(
*args, **default_kwargs
)(fun)
return decorator
@staticmethod
def catch_http_exception(func: Callable[..., RT]) -> Callable[..., RT]:
"""

Просмотреть файл

@ -63,6 +63,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
return self._conn
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def analyze_entities(
self,
document: Union[Dict, Document],
@ -97,6 +98,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def analyze_entity_sentiment(
self,
document: Union[Dict, Document],
@ -131,6 +133,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def analyze_sentiment(
self,
document: Union[Dict, Document],
@ -164,6 +167,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def analyze_syntax(
self,
document: Union[Dict, Document],
@ -198,6 +202,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def annotate_text(
self,
document: Union[Dict, Document],
@ -241,6 +246,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def classify_text(
self,
document: Union[Dict, Document],

Просмотреть файл

@ -55,6 +55,7 @@ class GCPSpeechToTextHook(GoogleCloudBaseHook):
self._client = SpeechClient(credentials=self._get_credentials(), client_info=self.client_info)
return self._client
@GoogleCloudBaseHook.quota_retry()
def recognize_speech(
self,
config: Union[Dict, RecognitionConfig],

Просмотреть файл

@ -64,6 +64,7 @@ class GCPTextToSpeechHook(GoogleCloudBaseHook):
return self._client
@GoogleCloudBaseHook.quota_retry()
def synthesize_speech(
self,
input_data: Union[Dict, SynthesisInput],

Просмотреть файл

@ -49,6 +49,7 @@ class CloudTranslateHook(GoogleCloudBaseHook):
self._client = Client(credentials=self._get_credentials(), client_info=self.client_info)
return self._client
@GoogleCloudBaseHook.quota_retry()
def translate(
self,
values: Union[str, List[str]],

Просмотреть файл

@ -61,6 +61,7 @@ class CloudVideoIntelligenceHook(GoogleCloudBaseHook):
)
return self._conn
@GoogleCloudBaseHook.quota_retry()
def annotate_video(
self,
input_uri: Optional[str] = None,

Просмотреть файл

@ -544,6 +544,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
return MessageToDict(response)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def batch_annotate_images(
self,
requests: Union[List[dict], List[AnnotateImageRequest]],
@ -567,6 +568,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
return MessageToDict(response)
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def text_detection(
self,
image: Union[Dict, Image],
@ -597,6 +599,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
return response
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def document_text_detection(
self,
image: Union[Dict, Image],
@ -627,6 +630,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
return response
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def label_detection(
self,
image: Union[Dict, Image],
@ -657,6 +661,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
return response
@GoogleCloudBaseHook.catch_http_exception
@GoogleCloudBaseHook.quota_retry()
def safe_search_detection(
self,
image: Union[Dict, Image],

Просмотреть файл

@ -23,10 +23,11 @@ import unittest
from io import StringIO
import google.auth
import tenacity
from google.api_core.exceptions import AlreadyExists, RetryError
from google.auth.environment_vars import CREDENTIALS
from google.auth.exceptions import GoogleAuthError
from google.cloud.exceptions import MovedPermanently
from google.cloud.exceptions import Forbidden, MovedPermanently
from googleapiclient.errors import HttpError
from parameterized import parameterized
@ -46,6 +47,64 @@ except GoogleAuthError:
MODULE_NAME = "airflow.gcp.hooks.base"
class NoForbiddenAfterCount:
"""Holds counter state for invoking a method several times in a row."""
def __init__(self, count, **kwargs):
self.counter = 0
self.count = count
self.kwargs = kwargs
def __call__(self):
"""
Raise an Forbidden until after count threshold has been crossed.
Then return True.
"""
if self.counter < self.count:
self.counter += 1
raise Forbidden(**self.kwargs)
return True
@hook.GoogleCloudBaseHook.quota_retry(wait=tenacity.wait_none())
def _retryable_test_with_temporary_quota_retry(thing):
return thing()
class QuotaRetryTestCase(unittest.TestCase): # ptlint: disable=invalid-name
def test_do_nothing_on_non_error(self):
result = _retryable_test_with_temporary_quota_retry(lambda: 42)
self.assertTrue(result, 42)
def test_retry_on_exception(self):
message = "POST https://translation.googleapis.com/language/translate/v2: User Rate Limit Exceeded"
errors = [
{
'message': 'User Rate Limit Exceeded',
'domain': 'usageLimits',
'reason': 'userRateLimitExceeded',
}
]
custom_fn = NoForbiddenAfterCount(
count=5,
message=message,
errors=errors
)
_retryable_test_with_temporary_quota_retry(custom_fn)
self.assertEqual(5, custom_fn.counter)
def test_raise_exception_on_non_quota_exception(self):
with self.assertRaisesRegex(Forbidden, "Daily Limit Exceeded"):
message = "POST https://translation.googleapis.com/language/translate/v2: Daily Limit Exceeded"
errors = [
{'message': 'Daily Limit Exceeded', 'domain': 'usageLimits', 'reason': 'dailyLimitExceeded'}
]
_retryable_test_with_temporary_quota_retry(
NoForbiddenAfterCount(5, message=message, errors=errors)
)
class TestCatchHttpException(unittest.TestCase):
# pylint:disable=no-method-argument,unused-argument
@parameterized.expand(