[AIRFLOW-4342] Use @cached_property instead of re-implementing it each time (#5126)
It's not many lines, but I just find this much clearer
This commit is contained in:
Родитель
a71d4b8613
Коммит
da024dded4
|
@ -18,12 +18,12 @@
|
|||
# under the License.
|
||||
from copy import deepcopy
|
||||
|
||||
from cached_property import cached_property
|
||||
from google.cloud.vision_v1 import ProductSearchClient, ImageAnnotatorClient
|
||||
from google.protobuf.json_format import MessageToDict
|
||||
|
||||
from airflow import AirflowException
|
||||
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
|
||||
from airflow.utils.decorators import cached_property
|
||||
|
||||
|
||||
class NameDeterminer:
|
||||
|
|
|
@ -103,19 +103,3 @@ if 'BUILDING_AIRFLOW_DOCS' in os.environ:
|
|||
# flake8: noqa: F811
|
||||
# Monkey patch hook to get good function headers while building docs
|
||||
apply_defaults = lambda x: x
|
||||
|
||||
|
||||
class cached_property:
|
||||
"""
|
||||
A decorator creating a property, the value of which is calculated only once and cached for later use.
|
||||
"""
|
||||
def __init__(self, func):
|
||||
self.func = func
|
||||
self.__doc__ = getattr(func, '__doc__')
|
||||
|
||||
def __get__(self, instance, cls=None):
|
||||
if instance is None:
|
||||
return self
|
||||
result = self.func(instance)
|
||||
instance.__dict__[self.func.__name__] = result
|
||||
return result
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
# under the License.
|
||||
import os
|
||||
|
||||
from cached_property import cached_property
|
||||
|
||||
from airflow import configuration
|
||||
from airflow.exceptions import AirflowException
|
||||
from airflow.utils.log.logging_mixin import LoggingMixin
|
||||
|
@ -39,7 +41,8 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
|
|||
self.closed = False
|
||||
self.upload_on_close = True
|
||||
|
||||
def _build_hook(self):
|
||||
@cached_property
|
||||
def hook(self):
|
||||
remote_conn_id = configuration.conf.get('core', 'REMOTE_LOG_CONN_ID')
|
||||
try:
|
||||
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
|
||||
|
@ -53,12 +56,6 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
|
|||
'and the GCS connection exists.', remote_conn_id, str(e)
|
||||
)
|
||||
|
||||
@property
|
||||
def hook(self):
|
||||
if self._hook is None:
|
||||
self._hook = self._build_hook()
|
||||
return self._hook
|
||||
|
||||
def set_context(self, ti):
|
||||
super(GCSTaskHandler, self).set_context(ti)
|
||||
# Log relative path is used to construct local and remote
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
# under the License.
|
||||
import os
|
||||
|
||||
from cached_property import cached_property
|
||||
|
||||
from airflow import configuration
|
||||
from airflow.utils.log.logging_mixin import LoggingMixin
|
||||
from airflow.utils.log.file_task_handler import FileTaskHandler
|
||||
|
@ -37,7 +39,8 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
|
|||
self.closed = False
|
||||
self.upload_on_close = True
|
||||
|
||||
def _build_hook(self):
|
||||
@cached_property
|
||||
def hook(self):
|
||||
remote_conn_id = configuration.conf.get('core', 'REMOTE_LOG_CONN_ID')
|
||||
try:
|
||||
from airflow.hooks.S3_hook import S3Hook
|
||||
|
@ -49,12 +52,6 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
|
|||
'the S3 connection exists.', remote_conn_id
|
||||
)
|
||||
|
||||
@property
|
||||
def hook(self):
|
||||
if self._hook is None:
|
||||
self._hook = self._build_hook()
|
||||
return self._hook
|
||||
|
||||
def set_context(self, ti):
|
||||
super(S3TaskHandler, self).set_context(ti)
|
||||
# Local location and remote location is needed to open and
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
from cached_property import cached_property
|
||||
|
||||
from airflow import configuration
|
||||
from airflow.utils.log.logging_mixin import LoggingMixin
|
||||
from airflow.utils.log.file_task_handler import FileTaskHandler
|
||||
|
@ -43,7 +45,8 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
|
|||
self.upload_on_close = True
|
||||
self.delete_local_copy = delete_local_copy
|
||||
|
||||
def _build_hook(self):
|
||||
@cached_property
|
||||
def hook(self):
|
||||
remote_conn_id = configuration.get('core', 'REMOTE_LOG_CONN_ID')
|
||||
try:
|
||||
from airflow.contrib.hooks.wasb_hook import WasbHook
|
||||
|
@ -55,12 +58,6 @@ class WasbTaskHandler(FileTaskHandler, LoggingMixin):
|
|||
'the Wasb connection exists.', remote_conn_id
|
||||
)
|
||||
|
||||
@property
|
||||
def hook(self):
|
||||
if self._hook is None:
|
||||
self._hook = self._build_hook()
|
||||
return self._hook
|
||||
|
||||
def set_context(self, ti):
|
||||
super(WasbTaskHandler, self).set_context(ti)
|
||||
# Local location and remote location is needed to open and
|
||||
|
|
1
setup.py
1
setup.py
|
@ -309,6 +309,7 @@ def do_setup():
|
|||
scripts=['airflow/bin/airflow'],
|
||||
install_requires=[
|
||||
'alembic>=0.9, <1.0',
|
||||
'cached_property~=1.5',
|
||||
'configparser>=3.5.0, <3.6.0',
|
||||
'croniter>=0.3.17, <0.4',
|
||||
'dill>=0.2.2, <0.3',
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
|
||||
import unittest
|
||||
|
||||
from airflow.utils.decorators import apply_defaults, cached_property
|
||||
from airflow.utils.decorators import apply_defaults
|
||||
from airflow.exceptions import AirflowException
|
||||
|
||||
|
||||
|
@ -73,29 +73,3 @@ class ApplyDefaultTest(unittest.TestCase):
|
|||
default_args = {'random_params': True}
|
||||
with self.assertRaisesRegexp(AirflowException, 'Argument.*test_param.*required'):
|
||||
DummyClass(default_args=default_args)
|
||||
|
||||
|
||||
class FixtureClass:
|
||||
@cached_property
|
||||
def value(self):
|
||||
"""Fixture docstring"""
|
||||
return 1, object()
|
||||
|
||||
|
||||
class FixtureSubClass(FixtureClass):
|
||||
pass
|
||||
|
||||
|
||||
class CachedPropertyTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.test_obj = FixtureClass()
|
||||
self.test_sub_obj = FixtureSubClass()
|
||||
|
||||
def test_cache_works(self):
|
||||
self.assertIs(self.test_obj.value, self.test_obj.value)
|
||||
self.assertIs(self.test_sub_obj.value, self.test_sub_obj.value)
|
||||
|
||||
def test_docstring(self):
|
||||
self.assertEqual(FixtureClass.value.__doc__, "Fixture docstring")
|
||||
self.assertEqual(FixtureSubClass.value.__doc__, "Fixture docstring")
|
||||
|
|
Загрузка…
Ссылка в новой задаче