[AIRFLOW-6056] Allow EmrAddStepsOperator to accept job_flow_name as alternative to job_flow_id (#6655)

This commit is contained in:
Aviem Zur 2019-12-10 15:49:38 +02:00 коммит произвёл Ash Berlin-Taylor
Родитель 239d51ed31
Коммит e37066086f
4 изменённых файлов: 104 добавлений и 13 удалений

Просмотреть файл

@ -30,12 +30,34 @@ class EmrHook(AwsHook):
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
self.emr_conn_id = emr_conn_id
self.region_name = region_name
self.conn = None
super().__init__(*args, **kwargs)
def get_conn(self):
self.conn = self.get_client_type('emr', self.region_name)
if not self.conn:
self.conn = self.get_client_type('emr', self.region_name)
return self.conn
def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
conn = self.get_conn()
response = conn.list_clusters(
ClusterStates=cluster_states
)
matching_clusters = list(
filter(lambda cluster: cluster['Name'] == emr_cluster_name, response['Clusters'])
)
if len(matching_clusters) == 1:
cluster_id = matching_clusters[0]['Id']
self.log.info('Found cluster name = %s id = %s' % (emr_cluster_name, cluster_id))
return cluster_id
elif len(matching_clusters) > 1:
raise AirflowException('More than one cluster found for name = %s' % emr_cluster_name)
else:
return None
def create_job_flow(self, job_flow_overrides):
"""
Creates a job flow using the config from the EMR connection.

Просмотреть файл

@ -28,33 +28,58 @@ class EmrAddStepsOperator(BaseOperator):
:param job_flow_id: id of the JobFlow to add steps to. (templated)
:type job_flow_id: str
:param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing
job_flow_id. will search for id of JobFlow with matching name in one of the states in
param cluster_states. Exactly one cluster like this should exist or will fail. (templated)
:type job_flow_name: str
:param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
(templated)
:type cluster_states: list
:param aws_conn_id: aws connection to uses
:type aws_conn_id: str
:param steps: boto3 style steps to be added to the jobflow. (templated)
:type steps: list
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
:type do_xcom_push: bool
"""
template_fields = ['job_flow_id', 'steps']
template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
template_ext = ()
ui_color = '#f9c915'
@apply_defaults
def __init__(
self,
job_flow_id,
job_flow_id=None,
job_flow_name=None,
cluster_states=None,
aws_conn_id='aws_default',
steps=None,
*args, **kwargs):
if kwargs.get('xcom_push') is not None:
raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
if not ((job_flow_id is None) ^ (job_flow_name is None)):
raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
super().__init__(*args, **kwargs)
steps = steps or []
self.job_flow_id = job_flow_id
self.aws_conn_id = aws_conn_id
self.job_flow_id = job_flow_id
self.job_flow_name = job_flow_name
self.cluster_states = cluster_states
self.steps = steps
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
self.log.info('Adding steps to %s', self.job_flow_id)
response = emr.add_job_flow_steps(JobFlowId=self.job_flow_id, Steps=self.steps)
job_flow_id = self.job_flow_id
if not job_flow_id:
job_flow_id = emr.get_cluster_id_by_name(self.job_flow_name, self.cluster_states)
if self.do_xcom_push:
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
self.log.info('Adding steps to %s', job_flow_id)
response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=self.steps)
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
raise AirflowException('Adding steps failed: %s' % response)

Просмотреть файл

@ -70,6 +70,26 @@ class TestEmrHook(unittest.TestCase):
# The AmiVersion comes back as {Requested,Running}AmiVersion fields.
self.assertEqual(cluster['RequestedAmiVersion'], '3.2')
@mock_emr
def test_get_cluster_id_by_name(self):
"""
Test that we can resolve cluster id by cluster name.
"""
hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')
job_flow = hook.create_job_flow({'Name': 'test_cluster',
'Instances': {'KeepJobFlowAliveWhenNoSteps': True}})
job_flow_id = job_flow['JobFlowId']
matching_cluster = hook.get_cluster_id_by_name('test_cluster', ['RUNNING', 'WAITING'])
self.assertEqual(matching_cluster, job_flow_id)
no_match = hook.get_cluster_id_by_name('foo', ['RUNNING', 'WAITING', 'BOOTSTRAPPING'])
self.assertIsNone(no_match)
if __name__ == '__main__':
unittest.main()

Просмотреть файл

@ -52,19 +52,27 @@ class TestEmrAddStepsOperator(unittest.TestCase):
}]
def setUp(self):
args = {
self.args = {
'owner': 'airflow',
'start_date': DEFAULT_DATE
}
# Mock out the emr_client (moto has incorrect response)
self.emr_client_mock = MagicMock()
# Mock out the emr_client creator
emr_session_mock = MagicMock()
emr_session_mock.client.return_value = self.emr_client_mock
self.boto3_session_mock = MagicMock(return_value=emr_session_mock)
self.mock_context = MagicMock()
self.operator = EmrAddStepsOperator(
task_id='test_task',
job_flow_id='j-8989898989',
aws_conn_id='aws_default',
steps=self._config,
dag=DAG('test_dag_id', default_args=args)
dag=DAG('test_dag_id', default_args=self.args)
)
def test_init(self):
@ -93,13 +101,29 @@ class TestEmrAddStepsOperator(unittest.TestCase):
def test_execute_returns_step_id(self):
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
# Mock out the emr_client creator
emr_session_mock = MagicMock()
emr_session_mock.client.return_value = self.emr_client_mock
self.boto3_session_mock = MagicMock(return_value=emr_session_mock)
with patch('boto3.session.Session', self.boto3_session_mock):
self.assertEqual(self.operator.execute(self.mock_context), ['s-2LH3R5GW3A53T'])
def test_init_with_cluster_name(self):
expected_job_flow_id = 'j-1231231234'
self.emr_client_mock.get_cluster_id_by_name.return_value = expected_job_flow_id
self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN
with patch('boto3.session.Session', self.boto3_session_mock):
self.assertEqual(self.operator.execute(None), ['s-2LH3R5GW3A53T'])
operator = EmrAddStepsOperator(
task_id='test_task',
job_flow_name='test_cluster',
cluster_states=['RUNNING', 'WAITING'],
aws_conn_id='aws_default',
dag=DAG('test_dag_id', default_args=self.args)
)
operator.execute(self.mock_context)
ti = self.mock_context['ti']
ti.xcom_push.assert_any_call(key='job_flow_id', value=expected_job_flow_id)
if __name__ == '__main__':