[AIRFLOW-3610] Add region param for EMR jobflow creation (#4418)
This commit is contained in:
Родитель
0bea2d6e8e
Коммит
3d5160fc80
|
@ -27,12 +27,13 @@ class EmrHook(AwsHook):
|
|||
create_job_flow method.
|
||||
"""
|
||||
|
||||
def __init__(self, emr_conn_id=None, *args, **kwargs):
|
||||
def __init__(self, emr_conn_id=None, region_name=None, *args, **kwargs):
|
||||
self.emr_conn_id = emr_conn_id
|
||||
self.region_name = region_name
|
||||
super(EmrHook, self).__init__(*args, **kwargs)
|
||||
|
||||
def get_conn(self):
|
||||
self.conn = self.get_client_type('emr')
|
||||
self.conn = self.get_client_type('emr', self.region_name)
|
||||
return self.conn
|
||||
|
||||
def create_job_flow(self, job_flow_overrides):
|
||||
|
|
|
@ -46,6 +46,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
|
|||
aws_conn_id='s3_default',
|
||||
emr_conn_id='emr_default',
|
||||
job_flow_overrides=None,
|
||||
region_name=None,
|
||||
*args, **kwargs):
|
||||
super(EmrCreateJobFlowOperator, self).__init__(*args, **kwargs)
|
||||
self.aws_conn_id = aws_conn_id
|
||||
|
@ -53,9 +54,10 @@ class EmrCreateJobFlowOperator(BaseOperator):
|
|||
if job_flow_overrides is None:
|
||||
job_flow_overrides = {}
|
||||
self.job_flow_overrides = job_flow_overrides
|
||||
self.region_name = region_name
|
||||
|
||||
def execute(self, context):
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id)
|
||||
emr = EmrHook(aws_conn_id=self.aws_conn_id, emr_conn_id=self.emr_conn_id, region_name=self.region_name)
|
||||
|
||||
self.log.info(
|
||||
'Creating JobFlow using aws-conn-id: %s, emr-conn-id: %s',
|
||||
|
|
|
@ -38,7 +38,7 @@ class TestEmrHook(unittest.TestCase):
|
|||
|
||||
@mock_emr
|
||||
def test_get_conn_returns_a_boto3_connection(self):
|
||||
hook = EmrHook(aws_conn_id='aws_default')
|
||||
hook = EmrHook(aws_conn_id='aws_default', region_name='ap-southeast-2')
|
||||
self.assertIsNotNone(hook.get_conn().list_clusters())
|
||||
|
||||
@mock_emr
|
||||
|
|
|
@ -71,12 +71,14 @@ class TestEmrCreateJobFlowOperator(unittest.TestCase):
|
|||
aws_conn_id='aws_default',
|
||||
emr_conn_id='emr_default',
|
||||
job_flow_overrides=self._config,
|
||||
region_name='ap-southeast-2',
|
||||
dag=DAG('test_dag_id', default_args=args)
|
||||
)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.operator.aws_conn_id, 'aws_default')
|
||||
self.assertEqual(self.operator.emr_conn_id, 'emr_default')
|
||||
self.assertEqual(self.operator.region_name, 'ap-southeast-2')
|
||||
|
||||
def test_render_template(self):
|
||||
ti = TaskInstance(self.operator, DEFAULT_DATE)
|
||||
|
|
Загрузка…
Ссылка в новой задаче