[AIRFLOW-6056] Allow EmrAddStepsOperator to accept job_flow_name as alternative to job_flow_id (#6655)
This commit is contained in:
Parent
239d51ed31
Commit
e37066086f
|
@ -30,12 +30,34 @@ class EmrHook(AwsHook):
|
|||
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
    """
    Interact with AWS EMR.

    :param emr_conn_id: id of the Airflow connection holding EMR settings
    :param region_name: AWS region for the EMR client; when None the
        connection/session default region is used
    """
    # The boto3 client is created lazily by get_conn(); start without one.
    self.conn = None
    self.emr_conn_id = emr_conn_id
    self.region_name = region_name
    super().__init__(*args, **kwargs)
|
||||
|
||||
def get_conn(self):
    """
    Return a boto3 EMR client, creating it on first use and caching it.

    The original body created a new client on every call before testing
    ``self.conn``, which made the cache check dead code; the client is now
    built only when no cached client exists.

    :return: cached boto3 EMR client
    """
    if not self.conn:
        self.conn = self.get_client_type('emr', self.region_name)
    return self.conn
|
||||
|
||||
def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
    """
    Fetch the id of the EMR cluster with the given name, limited to the
    given states. Returns an id only when exactly one cluster matches.

    :param emr_cluster_name: exact cluster name to look for
    :param cluster_states: list of acceptable cluster states to filter on
    :return: the matching cluster id, or None when no cluster matches
    :raises AirflowException: when more than one cluster matches the name
    """
    conn = self.get_conn()

    response = conn.list_clusters(ClusterStates=cluster_states)

    matching_clusters = [
        cluster for cluster in response['Clusters']
        if cluster['Name'] == emr_cluster_name
    ]

    if len(matching_clusters) == 1:
        cluster_id = matching_clusters[0]['Id']
        # Pass args lazily so the message is only built when INFO is enabled.
        self.log.info('Found cluster name = %s id = %s', emr_cluster_name, cluster_id)
        return cluster_id
    if len(matching_clusters) > 1:
        raise AirflowException('More than one cluster found for name = %s' % emr_cluster_name)
    return None
|
||||
|
||||
def create_job_flow(self, job_flow_overrides):
|
||||
"""
|
||||
Creates a job flow using the config from the EMR connection.
|
||||
|
|
|
@ -28,33 +28,58 @@ class EmrAddStepsOperator(BaseOperator):
|
|||
|
||||
:param job_flow_id: id of the JobFlow to add steps to. (templated)
|
||||
:type job_flow_id: str
|
||||
:param job_flow_name: name of the JobFlow to add steps to. Use as an alternative to passing
|
||||
job_flow_id. will search for id of JobFlow with matching name in one of the states in
|
||||
param cluster_states. Exactly one cluster like this should exist or will fail. (templated)
|
||||
:type job_flow_name: str
|
||||
:param cluster_states: Acceptable cluster states when searching for JobFlow id by job_flow_name.
|
||||
(templated)
|
||||
:type cluster_states: list
|
||||
:param aws_conn_id: aws connection to use
|
||||
:type aws_conn_id: str
|
||||
:param steps: boto3 style steps to be added to the jobflow. (templated)
|
||||
:type steps: list
|
||||
:param do_xcom_push: if True, job_flow_id is pushed to XCom with key job_flow_id.
|
||||
:type do_xcom_push: bool
|
||||
"""
|
||||
template_fields = ['job_flow_id', 'steps']
|
||||
template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
|
||||
template_ext = ()
|
||||
ui_color = '#f9c915'
|
||||
|
||||
@apply_defaults
def __init__(
        self,
        job_flow_id=None,
        job_flow_name=None,
        cluster_states=None,
        aws_conn_id='aws_default',
        steps=None,
        *args, **kwargs):
    """
    :param job_flow_id: id of the JobFlow to add steps to; mutually
        exclusive with job_flow_name
    :param job_flow_name: name of the JobFlow to resolve to an id at
        execution time; mutually exclusive with job_flow_id
    :param cluster_states: acceptable cluster states when searching for
        the JobFlow id by job_flow_name
    :param aws_conn_id: aws connection to use
    :param steps: boto3 style steps to be added to the jobflow
    :raises AirflowException: if the deprecated 'xcom_push' kwarg is given,
        or if not exactly one of job_flow_id/job_flow_name is specified

    The merged source declared ``job_flow_id`` twice in the signature and
    assigned ``self.job_flow_id`` twice; both duplicates are removed here.
    """
    if kwargs.get('xcom_push') is not None:
        raise AirflowException("'xcom_push' was deprecated, use 'do_xcom_push' instead")
    # XOR: exactly one of id/name must be provided (not both, not neither).
    if not ((job_flow_id is None) ^ (job_flow_name is None)):
        raise AirflowException('Exactly one of job_flow_id or job_flow_name must be specified.')
    super().__init__(*args, **kwargs)
    self.aws_conn_id = aws_conn_id
    self.job_flow_id = job_flow_id
    self.job_flow_name = job_flow_name
    self.cluster_states = cluster_states
    self.steps = steps or []
|
||||
|
||||
def execute(self, context):
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id).get_conn()
|
||||
|
||||
self.log.info('Adding steps to %s', self.job_flow_id)
|
||||
response = emr.add_job_flow_steps(JobFlowId=self.job_flow_id, Steps=self.steps)
|
||||
job_flow_id = self.job_flow_id
|
||||
|
||||
if not job_flow_id:
|
||||
job_flow_id = emr.get_cluster_id_by_name(self.job_flow_name, self.cluster_states)
|
||||
|
||||
if self.do_xcom_push:
|
||||
context['ti'].xcom_push(key='job_flow_id', value=job_flow_id)
|
||||
|
||||
self.log.info('Adding steps to %s', job_flow_id)
|
||||
response = emr.add_job_flow_steps(JobFlowId=job_flow_id, Steps=self.steps)
|
||||
|
||||
if not response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
raise AirflowException('Adding steps failed: %s' % response)
|
||||
|
|
|
@ -70,6 +70,26 @@ class TestEmrHook(unittest.TestCase):
|
|||
# The AmiVersion comes back as {Requested,Running}AmiVersion fields.
|
||||
self.assertEqual(cluster['RequestedAmiVersion'], '3.2')
|
||||
|
||||
@mock_emr
def test_get_cluster_id_by_name(self):
    """
    Test that we can resolve cluster id by cluster name.
    """
    hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default')

    created = hook.create_job_flow(
        {'Name': 'test_cluster',
         'Instances': {'KeepJobFlowAliveWhenNoSteps': True}})
    created_id = created['JobFlowId']

    # A unique name in an accepted state resolves to the created cluster.
    match = hook.get_cluster_id_by_name('test_cluster', ['RUNNING', 'WAITING'])
    self.assertEqual(match, created_id)

    # A name with no matching cluster resolves to None rather than raising.
    miss = hook.get_cluster_id_by_name('foo', ['RUNNING', 'WAITING', 'BOOTSTRAPPING'])
    self.assertIsNone(miss)
|
||||
|
||||
|
||||
# Allow running this test module directly from the command line.
if __name__ == '__main__':
    unittest.main()
|
||||
|
|
|
@ -52,19 +52,27 @@ class TestEmrAddStepsOperator(unittest.TestCase):
|
|||
}]
|
||||
|
||||
def setUp(self):
    """Build shared mocks and a default operator for each test.

    The merged source referenced an undefined local ``args`` and passed the
    ``dag=`` keyword twice (a SyntaxError); this keeps only the new-version
    lines: ``self.args`` and a single ``dag=`` argument.
    """
    self.args = {
        'owner': 'airflow',
        'start_date': DEFAULT_DATE
    }

    # Mock out the emr_client (moto has incorrect response)
    self.emr_client_mock = MagicMock()

    # Mock out the emr_client creator
    emr_session_mock = MagicMock()
    emr_session_mock.client.return_value = self.emr_client_mock
    self.boto3_session_mock = MagicMock(return_value=emr_session_mock)

    self.mock_context = MagicMock()

    self.operator = EmrAddStepsOperator(
        task_id='test_task',
        job_flow_id='j-8989898989',
        aws_conn_id='aws_default',
        steps=self._config,
        dag=DAG('test_dag_id', default_args=self.args)
    )
|
||||
|
||||
def test_init(self):
|
||||
|
@ -93,13 +101,29 @@ class TestEmrAddStepsOperator(unittest.TestCase):
|
|||
def test_execute_returns_step_id(self):
    """execute() should return the step ids reported by EMR.

    The merged source re-created the session/client mocks that setUp()
    already builds; the duplicated construction lines are dropped.
    """
    self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN

    with patch('boto3.session.Session', self.boto3_session_mock):
        self.assertEqual(self.operator.execute(self.mock_context), ['s-2LH3R5GW3A53T'])
|
||||
|
||||
def test_init_with_cluster_name(self):
    """Resolving job_flow_name to an id pushes that id to XCom.

    Stray lines from the replaced version of the previous test
    (asserting on ``self.operator.execute(None)``) are removed; this test
    builds its own operator configured with job_flow_name.
    """
    expected_job_flow_id = 'j-1231231234'

    self.emr_client_mock.get_cluster_id_by_name.return_value = expected_job_flow_id
    self.emr_client_mock.add_job_flow_steps.return_value = ADD_STEPS_SUCCESS_RETURN

    with patch('boto3.session.Session', self.boto3_session_mock):
        operator = EmrAddStepsOperator(
            task_id='test_task',
            job_flow_name='test_cluster',
            cluster_states=['RUNNING', 'WAITING'],
            aws_conn_id='aws_default',
            dag=DAG('test_dag_id', default_args=self.args)
        )

        operator.execute(self.mock_context)

    ti = self.mock_context['ti']
    ti.xcom_push.assert_any_call(key='job_flow_id', value=expected_job_flow_id)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Link in new issue