[AIRFLOW-3610] Add region param for EMR jobflow creation (#4418)

This commit is contained in:
Dana Ma 2019-01-07 08:51:01 +11:00 коммит произвёл Kaxil Naik
Родитель 0bea2d6e8e
Коммит 3d5160fc80
4 изменённых файлов: 9 добавлений и 4 удалений

Просмотреть файл

@ -27,12 +27,13 @@ class EmrHook(AwsHook):
create_job_flow method.
"""
def __init__(self, emr_conn_id=None, *args, **kwargs):
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
self.emr_conn_id = emr_conn_id
self.region_name = region_name
super(EmrHook, self).__init__(*args, **kwargs)
def get_conn(self):
self.conn = self.get_client_type('emr')
self.conn = self.get_client_type('emr', self.region_name)
return self.conn
def create_job_flow(self, job_flow_overrides):

Просмотреть файл

@ -46,6 +46,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
aws_conn_id='s3_default',
emr_conn_id='emr_default',
job_flow_overrides=None,
region_name=None,
*args, **kwargs):
super(EmrCreateJobFlowOperator, self).__init__(*args, **kwargs)
self.aws_conn_id = aws_conn_id
@ -53,9 +54,10 @@ class EmrCreateJobFlowOperator(BaseOperator):
if job_flow_overrides is None:
job_flow_overrides = {}
self.job_flow_overrides = job_flow_overrides
self.region_name = region_name
def execute(self, context):
emr = EmrHook(aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id)
emr = EmrHook(aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name)
self.log.info(
'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',

Просмотреть файл

@ -38,7 +38,7 @@ class TestEmrHook(unittest.TestCase):
@mock_emr
def test_get_conn_returns_a_boto3_connection(self):
hook = EmrHook(aws_conn_id='aws_default')
hook = EmrHook(aws_conn_id='aws_default', region_name='ap-southeast-2')
self.assertIsNotNone(hook.get_conn().list_clusters())
@mock_emr

Просмотреть файл

@ -71,12 +71,14 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
aws_conn_id='aws_default',
emr_conn_id='emr_default',
job_flow_overrides=self._config,
region_name='ap-southeast-2',
dag=DAG('test_dag_id', default_args=args)
)
def test_init(self):
self.assertEqual(self.operator.aws_conn_id, 'aws_default')
self.assertEqual(self.operator.emr_conn_id, 'emr_default')
self.assertEqual(self.operator.region_name, 'ap-southeast-2')
def test_render_template(self):
ti = TaskInstance(self.operator, DEFAULT_DATE)