[AIRFLOW-1888] Add AWS Redshift Cluster Sensor
Add an AWS Redshift Cluster Sensor to contrib, along with corresponding unit tests. Additionally, update the RedshiftHook cluster_status method to handle the ClusterNotFoundFault exception (returning 'cluster_not_found' instead of raising), add unit tests for it, and correct linting errors. Closes #2849 from andyxhadji/AIRFLOW-1888
Parent: 9ad6d1202d
Commit: 4936a80773
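
For context, a minimal usage sketch of the new sensor inside a DAG. The DAG id, dates, task id, cluster name, and timing values below are illustrative assumptions; only AwsRedshiftClusterSensor and its parameters come from this commit.

    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.sensors.aws_redshift_cluster_sensor import AwsRedshiftClusterSensor

    # Illustrative DAG; any existing DAG works the same way.
    dag = DAG('example_redshift_wait',
              start_date=datetime(2018, 1, 1),
              schedule_interval=None)

    # Blocks downstream tasks until the cluster reports the target status;
    # poke() compares RedshiftHook.cluster_status() against target_status.
    wait_for_cluster = AwsRedshiftClusterSensor(
        task_id='wait_for_redshift_cluster',      # illustrative task id
        cluster_identifier='my_cluster',          # illustrative cluster name
        target_status='available',
        aws_conn_id='aws_default',
        poke_interval=60,
        timeout=60 * 30,
        dag=dag)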

airflow/contrib/hooks/redshift_hook.py
@@ -14,6 +14,7 @@
 from airflow.contrib.hooks.aws_hook import AwsHook
 
 
 class RedshiftHook(AwsHook):
     """
     Interact with AWS Redshift, using the boto3 library
@@ -26,29 +27,36 @@ class RedshiftHook(AwsHook):
         """
         Return status of a cluster
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         """
-        # Use describe clusters
-        response = self.get_conn().describe_clusters(ClusterIdentifier=cluster_identifier)
-        # Possibly return error if cluster does not exist
-        return response['Clusters'][0]['ClusterStatus'] if response['Clusters'] else None
+        conn = self.get_conn()
+        try:
+            response = conn.describe_clusters(
+                ClusterIdentifier=cluster_identifier)['Clusters']
+            return response[0]['ClusterStatus'] if response else None
+        except conn.exceptions.ClusterNotFoundFault:
+            return 'cluster_not_found'
 
-    def delete_cluster(self, cluster_identifier, skip_final_cluster_snapshot=True, final_cluster_snapshot_identifier=''):
+    def delete_cluster(
+            self,
+            cluster_identifier,
+            skip_final_cluster_snapshot=True,
+            final_cluster_snapshot_identifier=''):
         """
         Delete a cluster and optionally create a snapshot
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
-        :param skip_final_cluster_snapshot: determines if a final cluster snapshot is made before shut-down
+        :param skip_final_cluster_snapshot: determines cluster snapshot creation
        :type skip_final_cluster_snapshot: bool
         :param final_cluster_snapshot_identifier: name of final cluster snapshot
         :type final_cluster_snapshot_identifier: str
         """
         response = self.get_conn().delete_cluster(
-            ClusterIdentifier = cluster_identifier,
-            SkipFinalClusterSnapshot = skip_final_cluster_snapshot,
-            FinalClusterSnapshotIdentifier = final_cluster_snapshot_identifier
+            ClusterIdentifier=cluster_identifier,
+            SkipFinalClusterSnapshot=skip_final_cluster_snapshot,
+            FinalClusterSnapshotIdentifier=final_cluster_snapshot_identifier
         )
         return response['Cluster'] if response['Cluster'] else None
 
@@ -56,11 +64,11 @@ class RedshiftHook(AwsHook):
         """
         Gets a list of snapshots for a cluster
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         """
         response = self.get_conn().describe_cluster_snapshots(
-            ClusterIdentifier = cluster_identifier
+            ClusterIdentifier=cluster_identifier
         )
         if 'Snapshots' not in response:
             return None
@@ -73,14 +81,14 @@ class RedshiftHook(AwsHook):
         """
         Restores a cluster from it's snapshot
 
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         :type snapshot_identifier: str
         """
         response = self.get_conn().restore_from_cluster_snapshot(
-            ClusterIdentifier = cluster_identifier,
-            SnapshotIdentifier = snapshot_identifier
+            ClusterIdentifier=cluster_identifier,
+            SnapshotIdentifier=snapshot_identifier
         )
         return response['Cluster'] if response['Cluster'] else None
 
@@ -90,7 +98,7 @@ class RedshiftHook(AwsHook):
 
         :param snapshot_identifier: unique identifier for a snapshot of a cluster
         :type snapshot_identifier: str
-        :param cluster_identifier: unique identifier of a cluster whose properties you are requesting
+        :param cluster_identifier: unique identifier of a cluster
         :type cluster_identifier: str
         """
         response = self.get_conn().create_cluster_snapshot(

airflow/contrib/sensors/aws_redshift_cluster_sensor.py (new file)
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from airflow.operators.sensors import BaseSensorOperator
+from airflow.contrib.hooks.redshift_hook import RedshiftHook
+from airflow.utils.decorators import apply_defaults
+
+
+class AwsRedshiftClusterSensor(BaseSensorOperator):
+    """
+    Waits for a Redshift cluster to reach a specific status.
+
+    :param cluster_identifier: The identifier for the cluster being pinged.
+    :type cluster_identifier: str
+    :param target_status: The cluster status desired.
+    :type target_status: str
+    """
+    template_fields = ('cluster_identifier', 'target_status')
+
+    @apply_defaults
+    def __init__(
+            self, cluster_identifier,
+            target_status='available',
+            aws_conn_id='aws_default',
+            *args, **kwargs):
+        super(AwsRedshiftClusterSensor, self).__init__(*args, **kwargs)
+        self.cluster_identifier = cluster_identifier
+        self.target_status = target_status
+        self.aws_conn_id = aws_conn_id
+
+    def poke(self, context):
+        self.log.info('Poking for status : {self.target_status}\n'
+                      'for cluster {self.cluster_identifier}'.format(**locals()))
+        hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
+        return hook.cluster_status(self.cluster_identifier) == self.target_status

@@ -25,6 +25,7 @@ try:
 except ImportError:
     mock_redshift = None
 
 
 @mock_redshift
 class TestRedshiftHook(unittest.TestCase):
     def setUp(self):
@@ -56,8 +57,12 @@ class TestRedshiftHook(unittest.TestCase):
     @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
     def test_restore_from_cluster_snapshot_returns_dict_with_cluster_data(self):
         hook = RedshiftHook(aws_conn_id='aws_default')
-        snapshot = hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
-        self.assertEqual(hook.restore_from_cluster_snapshot('test_cluster_3', 'test_snapshot')['ClusterIdentifier'], 'test_cluster_3')
+        hook.create_cluster_snapshot('test_snapshot', 'test_cluster')
+        self.assertEqual(
+            hook.restore_from_cluster_snapshot(
+                'test_cluster_3', 'test_snapshot'
+            )['ClusterIdentifier'],
+            'test_cluster_3')
 
     @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
     def test_delete_cluster_returns_a_dict_with_cluster_data(self):
@@ -73,5 +78,18 @@
         snapshot = hook.create_cluster_snapshot('test_snapshot_2', 'test_cluster')
         self.assertNotEqual(snapshot, None)
 
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_cluster_status_returns_cluster_not_found(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+        status = hook.cluster_status('test_cluster_not_here')
+        self.assertEqual(status, 'cluster_not_found')
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_cluster_status_returns_available_cluster(self):
+        hook = RedshiftHook(aws_conn_id='aws_default')
+        status = hook.cluster_status('test_cluster')
+        self.assertEqual(status, 'available')
+
+
 if __name__ == '__main__':
     unittest.main()

@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import boto3
+
+from airflow import configuration
+from airflow.contrib.sensors.aws_redshift_cluster_sensor import AwsRedshiftClusterSensor
+
+try:
+    from moto import mock_redshift
+except ImportError:
+    mock_redshift = None
+
+
+@mock_redshift
+class TestAwsRedshiftClusterSensor(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        client = boto3.client('redshift', region_name='us-east-1')
+        client.create_cluster(
+            ClusterIdentifier='test_cluster',
+            NodeType='dc1.large',
+            MasterUsername='admin',
+            MasterUserPassword='mock_password'
+        )
+        if len(client.describe_clusters()['Clusters']) == 0:
+            raise ValueError('AWS not properly mocked')
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_poke(self):
+        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
+                                      poke_interval=1,
+                                      timeout=5,
+                                      aws_conn_id='aws_default',
+                                      cluster_identifier='test_cluster',
+                                      target_status='available')
+        self.assertTrue(op.poke(None))
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_poke_false(self):
+        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
+                                      poke_interval=1,
+                                      timeout=5,
+                                      aws_conn_id='aws_default',
+                                      cluster_identifier='test_cluster_not_found',
+                                      target_status='available')
+
+        self.assertFalse(op.poke(None))
+
+    @unittest.skipIf(mock_redshift is None, 'mock_redshift package not present')
+    def test_poke_cluster_not_found(self):
+        op = AwsRedshiftClusterSensor(task_id='test_cluster_sensor',
+                                      poke_interval=1,
+                                      timeout=5,
+                                      aws_conn_id='aws_default',
+                                      cluster_identifier='test_cluster_not_found',
+                                      target_status='cluster_not_found')
+
+        self.assertTrue(op.poke(None))
+
+
+if __name__ == '__main__':
+    unittest.main()