[AIRFLOW-1888] Add AWS Redshift Cluster Sensor

Add an AWS Redshift Cluster Sensor to contrib, along with
corresponding unit tests. Additionally, update the Redshift Hook
cluster_status method to better handle the cluster-not-found
exception (it now returns 'cluster_not_found'), add unit tests,
and correct linting errors.

Closes #2849 from andyxhadji/AIRFLOW-1888
This commit is contained in:
Andy Hadjigeorgiou 2017-12-08 10:16:44 +01:00 коммит произвёл Bolke de Bruin
Родитель 9ad6d1202d
Коммит 4936a80773
4 изменённых файлов: 167 добавлений и 19 удалений

Просмотреть файл

@ -14,6 +14,7 @@
from airflow.contrib.hooks.aws_hook import AwsHook
class RedshiftHook(AwsHook):
"""
Interact with AWS Redshift, using the boto3 library
@ -26,29 +27,36 @@ class RedshiftHook(AwsHook):
"""
Return status of a cluster
:param cluster_identifier: unique identifier of a cluster whose properties you are requesting
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
# Use describe clusters
response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)
# Possibly return error if cluster does not exist
return response['Clusters'][0]['ClusterStatus'] if response['Clusters'] else None
conn = self.get_conn()
try:
response = conn.describe_clusters(
ClusterIdentifier=cluster_identifier)['Clusters']
return response[0]['ClusterStatus'] if response else None
except conn.exceptions.ClusterNotFoundFault:
return 'cluster_not_found'
def delete_cluster(self, cluster_identifier, skip_final_cluster_snapshot=True, final_cluster_snapshot_identifier=''):
def delete_cluster(
self,
cluster_identifier,
skip_final_cluster_snapshot=True,
final_cluster_snapshot_identifier=''):
"""
Delete a cluster and optionally create a snapshot
:param cluster_identifier: unique identifier of a cluster whose properties you are requesting
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
:param skip_final_cluster_snapshot: determines if a final cluster snapshot is made before shut-down
:param skip_final_cluster_snapshot: determines cluster snapshot creation
:type skip_final_cluster_snapshot: bool
:param final_cluster_snapshot_identifier: name of final cluster snapshot
:type final_cluster_snapshot_identifier: str
"""
response = self.get_conn().delete_cluster(
ClusterIdentifier = cluster_identifier,
SkipFinalClusterSnapshot = skip_final_cluster_snapshot,
FinalClusterSnapshotIdentifier = final_cluster_snapshot_identifier
ClusterIdentifier=cluster_identifier,
SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier
)
return response['Cluster'] if response['Cluster'] else None
@ -56,11 +64,11 @@ class RedshiftHook(AwsHook):
"""
Gets a list of snapshots for a cluster
:param cluster_identifier: unique identifier of a cluster whose properties you are requesting
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
response = self.get_conn().describe_cluster_snapshots(
ClusterIdentifier = cluster_identifier
ClusterIdentifier=cluster_identifier
)
if 'Snapshots' not in response:
return None
@ -73,14 +81,14 @@ class RedshiftHook(AwsHook):
"""
Restores a cluster from it's snapshot
:param cluster_identifier: unique identifier of a cluster whose properties you are requesting
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
:param snapshot_identifier: unique identifier for a snapshot of a cluster
:type snapshot_identifier: str
"""
response = self.get_conn().restore_from_cluster_snapshot(
ClusterIdentifier = cluster_identifier,
SnapshotIdentifier = snapshot_identifier
ClusterIdentifier=cluster_identifier,
SnapshotIdentifier=snapshot_identifier
)
return response['Cluster'] if response['Cluster'] else None
@ -90,7 +98,7 @@ class RedshiftHook(AwsHook):
:param snapshot_identifier: unique identifier for a snapshot of a cluster
:type snapshot_identifier: str
:param cluster_identifier: unique identifier of a cluster whose properties you are requesting
:param cluster_identifier: unique identifier of a cluster
:type cluster_identifier: str
"""
response = self.get_conn().create_cluster_snapshot(

Просмотреть файл

@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from airflow.operators.sensors import BaseSensorOperator
from airflow.contrib.hooks.redshift_hook import RedshiftHook
from airflow.utils.decorators import apply_defaults
class AwsRedshiftClusterSensor(BaseSensorOperator):
    """
    Waits for a Redshift cluster to reach a specific status.

    :param cluster_identifier: The identifier for the cluster being pinged.
    :type cluster_identifier: str
    :param target_status: The cluster status desired.
    :type target_status: str
    :param aws_conn_id: The connection id passed to ``RedshiftHook`` when
        the cluster status is looked up.
    :type aws_conn_id: str
    """
    template_fields = ('cluster_identifier', 'target_status')

    @apply_defaults
    def __init__(
            self, cluster_identifier,
            target_status='available',
            aws_conn_id='aws_default',
            *args, **kwargs):
        super(AwsRedshiftClusterSensor, self).__init__(*args, **kwargs)
        self.cluster_identifier = cluster_identifier
        self.target_status = target_status
        self.aws_conn_id = aws_conn_id

    def poke(self, context):
        """
        Check once whether the cluster currently has the target status.

        :param context: The task context; not used by this check.
        :return: True when the status reported by
            ``RedshiftHook.cluster_status`` equals ``target_status``,
            False otherwise.
        :rtype: bool
        """
        # Pass the values as logger arguments instead of formatting with
        # ``**locals()``: the rendered message is the same, but it no longer
        # depends on which local names happen to exist and is only built
        # when the record is actually emitted.
        self.log.info(
            'Poking for status : %s\nfor cluster %s',
            self.target_status, self.cluster_identifier)
        hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
        return hook.cluster_status(self.cluster_identifier) == self.target_status

Просмотреть файл

@ -25,6 +25,7 @@ try:
except ImportError:
mock_redshift = None
@mock_redshift
class TestRedshiftHook(unittest.TestCase):
def setUp(self):
@ -56,8 +57,12 @@ class TestRedshiftHook(unittest.TestCase):
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
hook = RedshiftHook(aws_conn_id='aws_default')
snapshot = hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
self.assertEqual(hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier'], 'test_cluster_3')
hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
self.assertEqual(
hook.restore_from_cluster_snapshot(
'test_cluster_3', 'test_snapshot'
)['ClusterIdentifier'],
'test_cluster_3')
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
def test_delete_cluster_returns_a_dict_with_cluster_data(self):
@ -73,5 +78,18 @@ class TestRedshiftHook(unittest.TestCase):
snapshot = hook.create_cluster_snapshot('test_snapshot_2', 'test_cluster')
self.assertNotEqual(snapshot, None)
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
def test_cluster_status_returns_cluster_not_found(self):
hook = RedshiftHook(aws_conn_id='aws_default')
status = hook.cluster_status('test_cluster_not_here')
self.assertEqual(status, 'cluster_not_found')
@unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
def test_cluster_status_returns_available_cluster(self):
hook = RedshiftHook(aws_conn_id='aws_default')
status = hook.cluster_status('test_cluster')
self.assertEqual(status, 'available')
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import boto3
from airflow import configuration
from airflow.contrib.sensors.aws_redshift_cluster_sensor import AwsRedshiftClusterSensor
try:
from moto import mock_redshift
except ImportError:
mock_redshift = None
class TestAwsRedshiftClusterSensor(unittest.TestCase):
    """Exercise ``AwsRedshiftClusterSensor.poke`` against a mocked Redshift."""

    def setUp(self):
        """Load the test config and create the ``test_cluster`` fixture."""
        configuration.load_test_config()

        client = boto3.client('redshift', region_name='us-east-1')
        client.create_cluster(
            ClusterIdentifier='test_cluster',
            NodeType='dc1.large',
            MasterUsername='admin',
            MasterUserPassword='mock_password'
        )
        # The cluster created above must be visible; if it is not, the
        # Redshift API was not mocked as expected.
        if not client.describe_clusters()['Clusters']:
            raise ValueError('AWS not properly mocked')

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_poke(self):
        """An existing cluster in the desired status makes poke succeed."""
        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
                                      poke_interval=1,
                                      timeout=5,
                                      aws_conn_id='aws_default',
                                      cluster_identifier='test_cluster',
                                      target_status='available')
        self.assertTrue(op.poke(None))

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_poke_false(self):
        """A cluster that was never created is not reported as available."""
        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
                                      poke_interval=1,
                                      timeout=5,
                                      aws_conn_id='aws_default',
                                      cluster_identifier='test_cluster_not_found',
                                      target_status='available')
        self.assertFalse(op.poke(None))

    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
    def test_poke_cluster_not_found(self):
        """Waiting for ``cluster_not_found`` succeeds for a missing cluster."""
        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
                                      poke_interval=1,
                                      timeout=5,
                                      aws_conn_id='aws_default',
                                      cluster_identifier='test_cluster_not_found',
                                      target_status='cluster_not_found')
        self.assertTrue(op.poke(None))


# moto is optional: the import fallback leaves ``mock_redshift`` as None.
# Using it directly as a class decorator would then raise TypeError at import
# time, before the per-test ``skipIf`` guards could skip anything, so the mock
# is applied only when it is actually available.
if mock_redshift is not None:
    TestAwsRedshiftClusterSensor = mock_redshift(TestAwsRedshiftClusterSensor)
if __name__ == '__main__':
unittest.main()