From 4936a807736557718dbc0690b92240806de5f3a9 Mon Sep 17 00:00:00 2001 From: Andy Hadjigeorgiou Date: Fri, 8 Dec 2017 10:16:44 +0100 Subject: [PATCH] [AIRFLOW-1888] Add AWS Redshift Cluster Sensor Add AWS Redshift Cluster Sensor to contrib, along with corresponding unit tests. Additionally, updated Redshift Hook cluster_status method to better handle cluster_not_found exception, added unit tests, and corrected linting errors. Closes #2849 from andyxhadji/AIRFLOW-1888 --- airflow/contrib/hooks/redshift_hook.py | 42 +++++----- .../sensors/aws_redshift_cluster_sensor.py | 46 +++++++++++ tests/contrib/hooks/test_redshift_hook.py | 22 +++++- .../test_aws_redshift_cluster_sensor.py | 76 +++++++++++++++++++ 4 files changed, 167 insertions(+), 19 deletions(-) create mode 100644 airflow/contrib/sensors/aws_redshift_cluster_sensor.py create mode 100644 tests/contrib/sensors/test_aws_redshift_cluster_sensor.py diff --git a/airflow/contrib/hooks/redshift_hook.py b/airflow/contrib/hooks/redshift_hook.py index 071caf2610..70a4854714 100644 --- a/airflow/contrib/hooks/redshift_hook.py +++ b/airflow/contrib/hooks/redshift_hook.py @@ -14,6 +14,7 @@ from airflow.contrib.hooks.aws_hook import AwsHook + class RedshiftHook(AwsHook): """ Interact with AWS Redshift, using the boto3 library @@ -26,29 +27,36 @@ class RedshiftHook(AwsHook): """ Return status of a cluster - :param cluster_identifier: unique identifier of a cluster whose properties you are requesting + :param cluster_identifier: unique identifier of a cluster :type cluster_identifier: str """ - # Use describe clusters - response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier) - # Possibly return error if cluster does not exist - return response['Clusters'][0]['ClusterStatus'] if response['Clusters'] else None + conn = self.get_conn() + try: + response = conn.describe_clusters( + ClusterIdentifier=cluster_identifier)['Clusters'] + return response[0]['ClusterStatus'] if response else None + except conn.exceptions.ClusterNotFoundFault: + return 'cluster_not_found' - def delete_cluster(self, cluster_identifier, skip_final_cluster_snapshot=True, final_cluster_snapshot_identifier=''): + def delete_cluster( + self, + cluster_identifier, + skip_final_cluster_snapshot=True, + final_cluster_snapshot_identifier=''): """ Delete a cluster and optionally create a snapshot - :param cluster_identifier: unique identifier of a cluster whose properties you are requesting + :param cluster_identifier: unique identifier of a cluster :type cluster_identifier: str - :param skip_final_cluster_snapshot: determines if a final cluster snapshot is made before shut-down + :param skip_final_cluster_snapshot: determines cluster snapshot creation :type skip_final_cluster_snapshot: bool :param final_cluster_snapshot_identifier: name of final cluster snapshot :type final_cluster_snapshot_identifier: str """ response = self.get_conn().delete_cluster( - ClusterIdentifier = cluster_identifier, - SkipFinalClusterSnapshot = skip_final_cluster_snapshot, - FinalClusterSnapshotIdentifier = final_cluster_snapshot_identifier + ClusterIdentifier=cluster_identifier, + SkipFinalClusterSnapshot=skip_final_cluster_snapshot, + FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier ) return response['Cluster'] if response['Cluster'] else None @@ -56,11 +64,11 @@ class RedshiftHook(AwsHook): """ Gets a list of snapshots for a cluster - :param cluster_identifier: unique identifier of a cluster whose properties you are requesting + :param cluster_identifier: unique identifier of a cluster :type cluster_identifier: str """ response = self.get_conn().describe_cluster_snapshots( - ClusterIdentifier = cluster_identifier + ClusterIdentifier=cluster_identifier ) if 'Snapshots' not in response: return None @@ -73,14 +81,14 @@ class RedshiftHook(AwsHook): """ Restores a cluster from it's snapshot - :param cluster_identifier: unique identifier of a cluster whose properties you are requesting + :param cluster_identifier: unique identifier of a cluster :type cluster_identifier: str :param snapshot_identifier: unique identifier for a snapshot of a cluster :type snapshot_identifier: str """ response = self.get_conn().restore_from_cluster_snapshot( - ClusterIdentifier = cluster_identifier, - SnapshotIdentifier = snapshot_identifier + ClusterIdentifier=cluster_identifier, + SnapshotIdentifier=snapshot_identifier ) return response['Cluster'] if response['Cluster'] else None @@ -90,7 +98,7 @@ class RedshiftHook(AwsHook): :param snapshot_identifier: unique identifier for a snapshot of a cluster :type snapshot_identifier: str - :param cluster_identifier: unique identifier of a cluster whose properties you are requesting + :param cluster_identifier: unique identifier of a cluster :type cluster_identifier: str """ response = self.get_conn().create_cluster_snapshot( diff --git a/airflow/contrib/sensors/aws_redshift_cluster_sensor.py b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py new file mode 100644 index 0000000000..8db85e6abe --- /dev/null +++ b/airflow/contrib/sensors/aws_redshift_cluster_sensor.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from airflow.operators.sensors import BaseSensorOperator +from airflow.contrib.hooks.redshift_hook import RedshiftHook +from airflow.utils.decorators import apply_defaults + + +class AwsRedshiftClusterSensor(BaseSensorOperator): + """ + Waits for a Redshift cluster to reach a specific status. + + :param cluster_identifier: The identifier for the cluster being pinged. + :type cluster_identifier: str + :param target_status: The cluster status desired. + :type target_status: str + """ + template_fields = ('cluster_identifier', 'target_status') + + @apply_defaults + def __init__( + self, cluster_identifier, + target_status='available', + aws_conn_id='aws_default', + *args, **kwargs): + super(AwsRedshiftClusterSensor, self).__init__(*args, **kwargs) + self.cluster_identifier = cluster_identifier + self.target_status = target_status + self.aws_conn_id = aws_conn_id + + def poke(self, context): + self.log.info('Poking for status : {self.target_status}\n' + 'for cluster {self.cluster_identifier}'.format(**locals())) + hook = RedshiftHook(aws_conn_id=self.aws_conn_id) + return hook.cluster_status(self.cluster_identifier) == self.target_status diff --git a/tests/contrib/hooks/test_redshift_hook.py b/tests/contrib/hooks/test_redshift_hook.py index 185be5e636..c7884a375c 100644 --- a/tests/contrib/hooks/test_redshift_hook.py +++ b/tests/contrib/hooks/test_redshift_hook.py @@ -25,6 +25,7 @@ try: except ImportError: mock_redshift = None + @mock_redshift class TestRedshiftHook(unittest.TestCase): def setUp(self): @@ -56,8 +57,12 @@ class TestRedshiftHook(unittest.TestCase): @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self): hook = RedshiftHook(aws_conn_id='aws_default') - snapshot = hook.create_cluster_snapshot('test_snapshot', 'test_cluster') - self.assertEqual(hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier'], 'test_cluster_3') + hook.create_cluster_snapshot('test_snapshot', 'test_cluster') + self.assertEqual( + hook.restore_from_cluster_snapshot( + 'test_cluster_3', 'test_snapshot' + )['ClusterIdentifier'], + 'test_cluster_3') @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') def test_delete_cluster_returns_a_dict_with_cluster_data(self): @@ -73,5 +78,18 @@ class TestRedshiftHook(unittest.TestCase): snapshot = hook.create_cluster_snapshot('test_snapshot_2', 'test_cluster') self.assertNotEqual(snapshot, None) + @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') + def test_cluster_status_returns_cluster_not_found(self): + hook = RedshiftHook(aws_conn_id='aws_default') + status = hook.cluster_status('test_cluster_not_here') + self.assertEqual(status, 'cluster_not_found') + + @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') + def test_cluster_status_returns_available_cluster(self): + hook = RedshiftHook(aws_conn_id='aws_default') + status = hook.cluster_status('test_cluster') + self.assertEqual(status, 'available') + + if __name__ == '__main__': unittest.main() diff --git a/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py b/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py new file mode 100644 index 0000000000..a5c9e66eac --- /dev/null +++ b/tests/contrib/sensors/test_aws_redshift_cluster_sensor.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import unittest +import boto3 + +from airflow import configuration +from airflow.contrib.sensors.aws_redshift_cluster_sensor import AwsRedshiftClusterSensor + +try: + from moto import mock_redshift +except ImportError: + mock_redshift = None + + +@mock_redshift +class TestAwsRedshiftClusterSensor(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + client = boto3.client('redshift', region_name='us-east-1') + client.create_cluster( + ClusterIdentifier='test_cluster', + NodeType='dc1.large', + MasterUsername='admin', + MasterUserPassword='mock_password' + ) + if len(client.describe_clusters()['Clusters']) == 0: + raise ValueError('AWS not properly mocked') + + @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') + def test_poke(self): + op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + cluster_identifier='test_cluster', + target_status='available') + self.assertTrue(op.poke(None)) + + @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') + def test_poke_false(self): + op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + cluster_identifier='test_cluster_not_found', + target_status='available') + + self.assertFalse(op.poke(None)) + + @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present') + def test_poke_cluster_not_found(self): + op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor', + poke_interval=1, + timeout=5, + aws_conn_id='aws_default', + cluster_identifier='test_cluster_not_found', + target_status='cluster_not_found') + + self.assertTrue(op.poke(None)) + + +if __name__ == '__main__': + unittest.main()