[AIRFLOW-2000] Support non-main dataflow job class
Closes #2942 from fenglu-g/master
Parent: 88130a5d7e
Commit: f6a1c3cf7f
airflow/contrib/hooks/gcp_dataflow_hook.py

@@ -151,23 +151,25 @@ class DataFlowHook(GoogleCloudBaseHook):
         http_authorized = self._authorize()
         return build('dataflow', 'v1b3', http=http_authorized)
 
-    def _start_dataflow(self, task_id, variables, dataflow,
-                        name, command_prefix, label_formatter):
+    def _start_dataflow(self, task_id, variables, name,
+                        command_prefix, label_formatter):
         cmd = command_prefix + self._build_cmd(task_id, variables,
-                                               dataflow, label_formatter)
+                                               label_formatter)
         _Dataflow(cmd).wait_for_done()
         _DataflowJob(self.get_conn(), variables['project'],
                      name, self.poll_sleep).wait_for_done()
 
-    def start_java_dataflow(self, task_id, variables, dataflow):
+    def start_java_dataflow(self, task_id, variables, dataflow, job_class=None):
         name = task_id + "-" + str(uuid.uuid1())[:8]
         variables['jobName'] = name
 
         def label_formatter(labels_dict):
             return ['--labels={}'.format(
                 json.dumps(labels_dict).replace(' ', ''))]
-        self._start_dataflow(task_id, variables, dataflow, name,
-                             ["java", "-jar"], label_formatter)
+        command_prefix = (["java", "-cp", dataflow, job_class] if job_class
+                          else ["java", "-jar", dataflow])
+        self._start_dataflow(task_id, variables, name,
+                             command_prefix, label_formatter)
 
     def start_template_dataflow(self, task_id, variables, parameters, dataflow_template):
         name = task_id + "-" + str(uuid.uuid1())[:8]

@@ -181,11 +183,12 @@ class DataFlowHook(GoogleCloudBaseHook):
         def label_formatter(labels_dict):
             return ['--labels={}={}'.format(key, value)
                     for key, value in labels_dict.items()]
-        self._start_dataflow(task_id, variables, dataflow, name,
-                             ["python"] + py_options, label_formatter)
+        self._start_dataflow(task_id, variables, name,
+                             ["python"] + py_options + [dataflow],
+                             label_formatter)
 
-    def _build_cmd(self, task_id, variables, dataflow, label_formatter):
-        command = [dataflow, "--runner=DataflowRunner"]
+    def _build_cmd(self, task_id, variables, label_formatter):
+        command = ["--runner=DataflowRunner"]
         if variables is not None:
             for attr, value in variables.items():
                 if attr == 'labels':
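The heart of the change is the new command_prefix expression: when a job_class is supplied, the jar goes on the classpath and the named class becomes the entry point; without one, the previous "java -jar" behaviour is preserved. A minimal standalone sketch of that selection logic (illustrative only, not the hook's code; the function name is invented):

def build_command_prefix(jar, job_class=None):
    # With a job_class, run "java -cp <jar> <class>"; otherwise fall back
    # to the jar's own Main-Class via "java -jar <jar>", as before.
    return (["java", "-cp", jar, job_class] if job_class
            else ["java", "-jar", jar])

print(build_command_prefix("wordcount.jar"))
# ['java', '-jar', 'wordcount.jar']
print(build_command_prefix("wordcount.jar", "com.example.WordCount"))
# ['java', '-cp', 'wordcount.jar', 'com.example.WordCount']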
airflow/contrib/operators/dataflow_operator.py

@@ -73,6 +73,7 @@ class DataFlowJavaOperator(BaseOperator):
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
             poll_sleep=10,
+            job_class=None,
             *args,
             **kwargs):
         """

@@ -103,6 +104,9 @@ class DataFlowJavaOperator(BaseOperator):
             Cloud Platform for the dataflow job status while the job is in the
             JOB_STATE_RUNNING state.
         :type poll_sleep: int
+        :param job_class: The name of the dataflow job class to be executed; it
+            is often not the main class configured in the dataflow jar file.
+        :type job_class: string
         """
         super(DataFlowJavaOperator, self).__init__(*args, **kwargs)

@@ -116,6 +120,7 @@ class DataFlowJavaOperator(BaseOperator):
         self.dataflow_default_options = dataflow_default_options
         self.options = options
         self.poll_sleep = poll_sleep
+        self.job_class = job_class
 
     def execute(self, context):
         bucket_helper = GoogleCloudBucketHelper(

@@ -128,7 +133,8 @@ class DataFlowJavaOperator(BaseOperator):
         dataflow_options = copy.copy(self.dataflow_default_options)
         dataflow_options.update(self.options)
 
-        hook.start_java_dataflow(self.task_id, dataflow_options, self.jar)
+        hook.start_java_dataflow(self.task_id, dataflow_options,
+                                 self.jar, self.job_class)
 
 
 class DataflowTemplateOperator(BaseOperator):
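With the operator change in place, a single jar bundling several pipelines can be launched against any of its classes from a DAG. A hypothetical usage sketch (the bucket, jar, class, and project names are invented for illustration; the import path is the one shown in the diff):

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.dataflow_operator import DataFlowJavaOperator

dag = DAG('dataflow_non_main_class',
          start_date=datetime(2018, 2, 1),
          schedule_interval=None)

launch = DataFlowJavaOperator(
    task_id='launch-secondary-pipeline',
    jar='gs://my-bucket/pipelines-bundle.jar',    # jar bundling several pipelines
    job_class='com.example.SecondaryPipeline',    # not the jar's Main-Class
    dataflow_default_options={'project': 'my-gcp-project'},
    poll_sleep=10,
    dag=dag)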
tests/contrib/hooks/test_gcp_dataflow_hook.py

@@ -37,6 +37,7 @@ PARAMETERS = {
 }
 PY_FILE = 'apache_beam.examples.wordcount'
 JAR_FILE = 'unitest.jar'
+JOB_CLASS = 'com.example.UnitTest'
 PY_OPTIONS = ['-m']
 DATAFLOW_OPTIONS_PY = {
     'project': 'test',

@@ -62,7 +63,7 @@ def mock_init(self, gcp_conn_id, delegate_to=None):
     pass
 
 
-class DataFlowPythonHookTest(unittest.TestCase):
+class DataFlowHookTest(unittest.TestCase):
 
     def setUp(self):
         with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),

@@ -115,6 +116,30 @@ class DataFlowHookTest(unittest.TestCase):
         self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
                              sorted(EXPECTED_CMD))
 
+    @mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
+    @mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
+    @mock.patch(DATAFLOW_STRING.format('_Dataflow'))
+    @mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
+    def test_start_java_dataflow_with_job_class(
+            self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
+        mock_uuid.return_value = MOCK_UUID
+        mock_conn.return_value = None
+        dataflow_instance = mock_dataflow.return_value
+        dataflow_instance.wait_for_done.return_value = None
+        dataflowjob_instance = mock_dataflowjob.return_value
+        dataflowjob_instance.wait_for_done.return_value = None
+        self.dataflow_hook.start_java_dataflow(
+            task_id=TASK_ID, variables=DATAFLOW_OPTIONS_JAVA,
+            dataflow=JAR_FILE, job_class=JOB_CLASS)
+        EXPECTED_CMD = ['java', '-cp', JAR_FILE, JOB_CLASS,
+                        '--runner=DataflowRunner', '--project=test',
+                        '--stagingLocation=gs://test/staging',
+                        '--labels={"foo":"bar"}',
+                        '--jobName={}-{}'.format(TASK_ID, MOCK_UUID)]
+        self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
+                             sorted(EXPECTED_CMD))
+
     @mock.patch('airflow.contrib.hooks.gcp_dataflow_hook._Dataflow.log')
     @mock.patch('subprocess.Popen')
     @mock.patch('select.select')
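The '--labels={"foo":"bar"}' flag expected by the new test comes from the Java label formatter, which JSON-encodes the whole labels dict as one flag; the Python path (second hook hunk above) emits one flag per label instead. A standalone sketch contrasting the two formatter bodies shown in the hunks (for illustration only):

import json

def java_label_formatter(labels_dict):
    # Java pipelines get a single JSON-encoded --labels flag.
    return ['--labels={}'.format(json.dumps(labels_dict).replace(' ', ''))]

def python_label_formatter(labels_dict):
    # Python pipelines get one --labels=key=value flag per label.
    return ['--labels={}={}'.format(key, value)
            for key, value in labels_dict.items()]

labels = {'foo': 'bar'}
print(java_label_formatter(labels))    # ['--labels={"foo":"bar"}']
print(python_label_formatter(labels))  # ['--labels=foo=bar']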
tests/contrib/operators/test_dataflow_operator.py

@@ -16,7 +16,7 @@
 import unittest
 
 from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator, \
-    DataflowTemplateOperator
+    DataFlowJavaOperator, DataflowTemplateOperator
 from airflow.version import version

@@ -36,8 +36,10 @@ PARAMETERS = {
     'output': 'gs://test/output/my_output'
 }
 PY_FILE = 'gs://my-bucket/my-object.py'
+JAR_FILE = 'example/test.jar'
+JOB_CLASS = 'com.test.NotMain'
 PY_OPTIONS = ['-m']
-DEFAULT_OPTIONS_PYTHON = {
+DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
     'project': 'test',
     'stagingLocation': 'gs://test/staging',
 }

@@ -105,6 +107,44 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
         self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
 
 
+class DataFlowJavaOperatorTest(unittest.TestCase):
+
+    def setUp(self):
+        self.dataflow = DataFlowJavaOperator(
+            task_id=TASK_ID,
+            jar=JAR_FILE,
+            job_class=JOB_CLASS,
+            dataflow_default_options=DEFAULT_OPTIONS_JAVA,
+            options=ADDITIONAL_OPTIONS,
+            poll_sleep=POLL_SLEEP)
+
+    def test_init(self):
+        """Test DataFlowJavaOperator instance is properly initialized."""
+        self.assertEqual(self.dataflow.task_id, TASK_ID)
+        self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
+        self.assertEqual(self.dataflow.dataflow_default_options,
+                         DEFAULT_OPTIONS_JAVA)
+        self.assertEqual(self.dataflow.job_class, JOB_CLASS)
+        self.assertEqual(self.dataflow.jar, JAR_FILE)
+        self.assertEqual(self.dataflow.options,
+                         EXPECTED_ADDITIONAL_OPTIONS)
+
+    @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
+    @mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
+    def test_exec(self, gcs_hook, dataflow_mock):
+        """Test DataFlowHook is created and the right args are passed to
+        start_java_dataflow.
+
+        """
+        start_java_hook = dataflow_mock.return_value.start_java_dataflow
+        gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
+        self.dataflow.execute(None)
+        self.assertTrue(dataflow_mock.called)
+        gcs_download_hook.assert_called_once_with(JAR_FILE)
+        start_java_hook.assert_called_once_with(TASK_ID, mock.ANY,
+                                                mock.ANY, JOB_CLASS)
+
+
 class DataFlowTemplateOperatorTest(unittest.TestCase):
 
     def setUp(self):