[AIRFLOW-4965] Handle quota exceptions in GCP operators (#6305)
This commit is contained in:
Родитель
b08f86290e
Коммит
417754a14a
|
@ -23,6 +23,7 @@ This module contains a Google Cloud API base hook.
|
|||
|
||||
import functools
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from contextlib import contextmanager
|
||||
|
@ -32,17 +33,69 @@ import google.auth
|
|||
import google.oauth2.service_account
|
||||
import google_auth_httplib2
|
||||
import httplib2
|
||||
from google.api_core.exceptions import AlreadyExists, GoogleAPICallError, RetryError
|
||||
import tenacity
|
||||
from google.api_core.exceptions import (
|
||||
AlreadyExists, Forbidden, GoogleAPICallError, ResourceExhausted, RetryError,
|
||||
)
|
||||
from google.api_core.gapic_v1.client_info import ClientInfo
|
||||
from google.auth.environment_vars import CREDENTIALS
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from airflow import version
|
||||
from airflow import LoggingMixin, version
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.hooks.base_hook import BaseHook
|
||||
|
||||
logger = LoggingMixin().log
|
||||
|
||||
_DEFAULT_SCOPES = ('https://www.googleapis.com/auth/cloud-platform',) # type: Sequence[str]
|
||||
|
||||
# Markers used by the request-repeating mechanism to recognise a temporarily exceeded quota.

# Substrings searched for in the details of the errors attached to a ``ResourceExhausted`` exception.
INVALID_KEYS = [
    "DefaultRequestsPerMinutePerProject",
    "DefaultRequestsPerMinutePerUser",
    "RequestsPerMinutePerProject",
    "Resource has been exhausted (e.g. check quota).",
]

# Substrings searched for in the ``reason`` of the errors attached to a ``Forbidden`` exception.
INVALID_REASONS = ["userRateLimitExceeded"]
|
||||
|
||||
|
||||
def is_soft_quota_exception(exception: Exception) -> bool:
    """
    Tell whether an exception reports that a temporary (soft) quota has been exceeded.

    API for Google services does not have a standardized way to report quota violation errors.
    The function has been adapted by trial and error to the following services:

    * Google Translate
    * Google Vision
    * Google Text-to-Speech
    * Google Speech-to-Text
    * Google Natural Language
    * Google Video Intelligence

    :param exception: the exception to inspect
    :return: ``True`` when the exception is a ``Forbidden`` whose errors carry one of
        ``INVALID_REASONS`` as reason, or a ``ResourceExhausted`` whose error details contain one of
        ``INVALID_KEYS``; ``False`` for everything else
    """
    if isinstance(exception, Forbidden):
        # An error entry without a usable "reason" simply cannot match. Looking it up tolerantly
        # keeps this predicate from raising KeyError/TypeError and hiding the original API error.
        return any(
            reason in (error.get("reason") or "")
            for reason in INVALID_REASONS
            for error in exception.errors
        )

    if isinstance(exception, ResourceExhausted):
        return any(
            key in error.details()
            for key in INVALID_KEYS
            for error in exception.errors
        )

    return False
|
||||
|
||||
|
||||
class retry_if_temporary_quota(tenacity.retry_if_exception):  # pylint: disable=invalid-name
    """Retry strategy that retries when the raised exception reports an exceeded temporary quota limit."""

    def __init__(self):
        # The decision whether an exception qualifies is delegated to is_soft_quota_exception.
        super().__init__(predicate=is_soft_quota_exception)
|
||||
|
||||
|
||||
# Type variable for a wrapped function's return value; used in the signature of catch_http_exception.
RT = TypeVar('RT') # pylint: disable=invalid-name
|
||||
|
||||
|
@ -221,6 +274,25 @@ class GoogleCloudBaseHook(BaseHook):
|
|||
return [s.strip() for s in scope_value.split(',')] \
|
||||
if scope_value else _DEFAULT_SCOPES
|
||||
|
||||
@staticmethod
def quota_retry(*args, **kwargs) -> Callable:
    """
    Build a decorator that repeats a request in response to exceeding a temporary quota limit.

    Positional arguments are passed to ``tenacity.retry`` unchanged. Keyword arguments are passed
    to it as well and override the defaults prepared here: an exponential wait, retrying only on
    temporary quota exceptions, and DEBUG-level logging through tenacity's before/after hooks.
    """
    def decorator(fun: Callable):
        retry_options = {
            **dict(
                wait=tenacity.wait_exponential(multiplier=1, max=100),
                retry=retry_if_temporary_quota(),
                before=tenacity.before_log(logger, logging.DEBUG),
                after=tenacity.after_log(logger, logging.DEBUG),
            ),
            # Listed last so that caller-supplied options replace the defaults above.
            **kwargs,
        }
        apply_retry = tenacity.retry(*args, **retry_options)
        return apply_retry(fun)

    return decorator
|
||||
|
||||
@staticmethod
|
||||
def catch_http_exception(func: Callable[..., RT]) -> Callable[..., RT]:
|
||||
"""
|
||||
|
|
|
@ -63,6 +63,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
|
|||
return self._conn
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def analyze_entities(
|
||||
self,
|
||||
document: Union[Dict, Document],
|
||||
|
@ -97,6 +98,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
|
|||
)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def analyze_entity_sentiment(
|
||||
self,
|
||||
document: Union[Dict, Document],
|
||||
|
@ -131,6 +133,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
|
|||
)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def analyze_sentiment(
|
||||
self,
|
||||
document: Union[Dict, Document],
|
||||
|
@ -164,6 +167,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
|
|||
)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def analyze_syntax(
|
||||
self,
|
||||
document: Union[Dict, Document],
|
||||
|
@ -198,6 +202,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
|
|||
)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def annotate_text(
|
||||
self,
|
||||
document: Union[Dict, Document],
|
||||
|
@ -241,6 +246,7 @@ class CloudNaturalLanguageHook(GoogleCloudBaseHook):
|
|||
)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def classify_text(
|
||||
self,
|
||||
document: Union[Dict, Document],
|
||||
|
|
|
@ -55,6 +55,7 @@ class GCPSpeechToTextHook(GoogleCloudBaseHook):
|
|||
self._client = SpeechClient(credentials=self._get_credentials(), client_info=self.client_info)
|
||||
return self._client
|
||||
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def recognize_speech(
|
||||
self,
|
||||
config: Union[Dict, RecognitionConfig],
|
||||
|
|
|
@ -64,6 +64,7 @@ class GCPTextToSpeechHook(GoogleCloudBaseHook):
|
|||
|
||||
return self._client
|
||||
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def synthesize_speech(
|
||||
self,
|
||||
input_data: Union[Dict, SynthesisInput],
|
||||
|
|
|
@ -49,6 +49,7 @@ class CloudTranslateHook(GoogleCloudBaseHook):
|
|||
self._client = Client(credentials=self._get_credentials(), client_info=self.client_info)
|
||||
return self._client
|
||||
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def translate(
|
||||
self,
|
||||
values: Union[str, List[str]],
|
||||
|
|
|
@ -61,6 +61,7 @@ class CloudVideoIntelligenceHook(GoogleCloudBaseHook):
|
|||
)
|
||||
return self._conn
|
||||
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def annotate_video(
|
||||
self,
|
||||
input_uri: Optional[str] = None,
|
||||
|
|
|
@ -544,6 +544,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
|
|||
return MessageToDict(response)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def batch_annotate_images(
|
||||
self,
|
||||
requests: Union[List[dict], List[AnnotateImageRequest]],
|
||||
|
@ -567,6 +568,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
|
|||
return MessageToDict(response)
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def text_detection(
|
||||
self,
|
||||
image: Union[Dict, Image],
|
||||
|
@ -597,6 +599,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
|
|||
return response
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def document_text_detection(
|
||||
self,
|
||||
image: Union[Dict, Image],
|
||||
|
@ -627,6 +630,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
|
|||
return response
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def label_detection(
|
||||
self,
|
||||
image: Union[Dict, Image],
|
||||
|
@ -657,6 +661,7 @@ class CloudVisionHook(GoogleCloudBaseHook):
|
|||
return response
|
||||
|
||||
@GoogleCloudBaseHook.catch_http_exception
|
||||
@GoogleCloudBaseHook.quota_retry()
|
||||
def safe_search_detection(
|
||||
self,
|
||||
image: Union[Dict, Image],
|
||||
|
|
|
@ -23,10 +23,11 @@ import unittest
|
|||
from io import StringIO
|
||||
|
||||
import google.auth
|
||||
import tenacity
|
||||
from google.api_core.exceptions import AlreadyExists, RetryError
|
||||
from google.auth.environment_vars import CREDENTIALS
|
||||
from google.auth.exceptions import GoogleAuthError
|
||||
from google.cloud.exceptions import MovedPermanently
|
||||
from google.cloud.exceptions import Forbidden, MovedPermanently
|
||||
from googleapiclient.errors import HttpError
|
||||
from parameterized import parameterized
|
||||
|
||||
|
@ -46,6 +47,64 @@ except GoogleAuthError:
|
|||
MODULE_NAME = "airflow.gcp.hooks.base"
|
||||
|
||||
|
||||
class NoForbiddenAfterCount:
    """Callable that keeps a counter so it can fail a fixed number of times before succeeding."""

    def __init__(self, count, **kwargs):
        # kwargs are kept as-is and later forwarded to the Forbidden constructor.
        self.kwargs = kwargs
        self.count = count
        self.counter = 0

    def __call__(self):
        """
        Raise a Forbidden on each call until ``count`` calls have failed.

        From then on return True.
        """
        if self.counter >= self.count:
            return True
        self.counter += 1
        raise Forbidden(**self.kwargs)
|
||||
|
||||
|
||||
@hook.GoogleCloudBaseHook.quota_retry(wait=tenacity.wait_none())
def _retryable_test_with_temporary_quota_retry(thing):
    """Call ``thing`` under the quota retry decorator, with waiting between attempts disabled."""
    outcome = thing()
    return outcome
|
||||
|
||||
|
||||
class QuotaRetryTestCase(unittest.TestCase):
    """Tests of the ``quota_retry`` decorator, exercised via ``_retryable_test_with_temporary_quota_retry``."""

    def test_do_nothing_on_non_error(self):
        """A callable that does not raise is called and its result is returned unchanged."""
        result = _retryable_test_with_temporary_quota_retry(lambda: 42)
        # assertTrue(result, 42) would use 42 as the failure message and accept any truthy result;
        # compare the value itself instead.
        self.assertEqual(42, result)

    def test_retry_on_exception(self):
        """A Forbidden with reason ``userRateLimitExceeded`` is retried until the callable succeeds."""
        message = "POST https://translation.googleapis.com/language/translate/v2: User Rate Limit Exceeded"
        errors = [
            {
                'message': 'User Rate Limit Exceeded',
                'domain': 'usageLimits',
                'reason': 'userRateLimitExceeded',
            }
        ]
        custom_fn = NoForbiddenAfterCount(
            count=5,
            message=message,
            errors=errors
        )
        _retryable_test_with_temporary_quota_retry(custom_fn)
        # Five failing calls were made before the call that finally succeeded.
        self.assertEqual(5, custom_fn.counter)

    def test_raise_exception_on_non_quota_exception(self):
        """A Forbidden with reason ``dailyLimitExceeded`` is not retried and reaches the caller."""
        message = "POST https://translation.googleapis.com/language/translate/v2: Daily Limit Exceeded"
        errors = [
            {'message': 'Daily Limit Exceeded', 'domain': 'usageLimits', 'reason': 'dailyLimitExceeded'}
        ]
        failing_fn = NoForbiddenAfterCount(5, message=message, errors=errors)

        # Only the call under test sits inside the context manager, so nothing in the
        # fixture setup can satisfy the assertion by accident.
        with self.assertRaisesRegex(Forbidden, "Daily Limit Exceeded"):
            _retryable_test_with_temporary_quota_retry(failing_fn)
|
||||
|
||||
|
||||
class TestCatchHttpException(unittest.TestCase):
|
||||
# pylint:disable=no-method-argument,unused-argument
|
||||
@parameterized.expand(
|
||||
|
|
Загрузка…
Ссылка в новой задаче