[AIRFLOW-2000] Support non-main dataflow job class

Closes #2942 from fenglu-g/master
This commit is contained in:
fenglu-g 2018-01-16 09:32:32 -08:00 committed by Chris Riccomini
Parent 88130a5d7e
Commit f6a1c3cf7f
4 changed files with 88 additions and 14 deletions

View file

@@ -151,23 +151,25 @@ class DataFlowHook(GoogleCloudBaseHook):
http_authorized = self._authorize()
return build('dataflow', 'v1b3', http=http_authorized)
def _start_dataflow(self, task_id, variables, dataflow,
name, command_prefix, label_formatter):
def _start_dataflow(self, task_id, variables, name,
command_prefix, label_formatter):
cmd = command_prefix + self._build_cmd(task_id, variables,
dataflow, label_formatter)
label_formatter)
_Dataflow(cmd).wait_for_done()
_DataflowJob(self.get_conn(), variables['project'],
name, self.poll_sleep).wait_for_done()
def start_java_dataflow(self, task_id, variables, dataflow):
def start_java_dataflow(self, task_id, variables, dataflow, job_class=None):
name = task_id + "-" + str(uuid.uuid1())[:8]
variables['jobName'] = name
def label_formatter(labels_dict):
return ['--labels={}'.format(
json.dumps(labels_dict).replace(' ', ''))]
self._start_dataflow(task_id, variables, dataflow, name,
["java", "-jar"], label_formatter)
command_prefix = (["java", "-cp", dataflow, job_class] if job_class
else ["java", "-jar", dataflow])
self._start_dataflow(task_id, variables, name,
command_prefix, label_formatter)
def start_template_dataflow(self, task_id, variables, parameters, dataflow_template):
name = task_id + "-" + str(uuid.uuid1())[:8]
@@ -181,11 +183,12 @@ class DataFlowHook(GoogleCloudBaseHook):
def label_formatter(labels_dict):
return ['--labels={}={}'.format(key, value)
for key, value in labels_dict.items()]
self._start_dataflow(task_id, variables, dataflow, name,
["python"] + py_options, label_formatter)
self._start_dataflow(task_id, variables, name,
["python"] + py_options + [dataflow],
label_formatter)
def _build_cmd(self, task_id, variables, dataflow, label_formatter):
command = [dataflow, "--runner=DataflowRunner"]
def _build_cmd(self, task_id, variables, label_formatter):
command = ["--runner=DataflowRunner"]
if variables is not None:
for attr, value in variables.items():
if attr == 'labels':
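
For readers skimming the hook change: start_java_dataflow now chooses between "java -jar <jar>" and "java -cp <jar> <job_class>", and _build_cmd only contributes the pipeline options. A minimal standalone sketch of that command assembly; the helper name build_java_command and the sample values are illustrative, not part of the hook's API:

def build_java_command(jar, job_class=None, options=None):
    # Explicit main class via -cp when job_class is set,
    # otherwise run the jar's manifest main class via -jar.
    prefix = (["java", "-cp", jar, job_class] if job_class
              else ["java", "-jar", jar])
    args = ["--runner=DataflowRunner"]
    for key, value in (options or {}).items():
        args.append("--{}={}".format(key, value))
    return prefix + args

# e.g. ['java', '-cp', 'pipeline.jar', 'com.example.WordCount',
#       '--runner=DataflowRunner', '--project=my-project']
print(build_java_command("pipeline.jar", "com.example.WordCount",
                         {"project": "my-project"}))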

View file

@@ -73,6 +73,7 @@ class DataFlowJavaOperator(BaseOperator):
gcp_conn_id='google_cloud_default',
delegate_to=None,
poll_sleep=10,
job_class=None,
*args,
**kwargs):
"""
@@ -103,6 +104,9 @@ class DataFlowJavaOperator(BaseOperator):
Cloud Platform for the dataflow job status while the job is in the
JOB_STATE_RUNNING state.
:type poll_sleep: int
:param job_class: The name of the dataflow job class to be executed; it
is often not the main class configured in the dataflow jar file.
:type job_class: string
"""
super(DataFlowJavaOperator, self).__init__(*args, **kwargs)
@@ -116,6 +120,7 @@ class DataFlowJavaOperator(BaseOperator):
self.dataflow_default_options = dataflow_default_options
self.options = options
self.poll_sleep = poll_sleep
self.job_class = job_class
def execute(self, context):
bucket_helper = GoogleCloudBucketHelper(
@@ -128,7 +133,8 @@ class DataFlowJavaOperator(BaseOperator):
dataflow_options = copy.copy(self.dataflow_default_options)
dataflow_options.update(self.options)
hook.start_java_dataflow(self.task_id, dataflow_options, self.jar)
hook.start_java_dataflow(self.task_id, dataflow_options,
self.jar, self.job_class)
class DataflowTemplateOperator(BaseOperator):
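
With the new job_class argument wired through the operator, a DAG can point an existing fat jar at a pipeline class other than its manifest main class. A hedged usage sketch; the DAG id, bucket paths, and class name below are made up for illustration:

from datetime import datetime

from airflow import DAG
from airflow.contrib.operators.dataflow_operator import DataFlowJavaOperator

dag = DAG('dataflow_java_example', start_date=datetime(2018, 1, 1),
          schedule_interval=None)

wordcount = DataFlowJavaOperator(
    task_id='wordcount-java',
    jar='gs://my-bucket/pipelines/bundled-pipelines.jar',
    # Run a pipeline class other than the jar's manifest main class.
    job_class='com.example.dataflow.WordCount',
    dataflow_default_options={
        'project': 'my-project',
        'stagingLocation': 'gs://my-bucket/staging',
    },
    options={'output': 'gs://my-bucket/output/'},
    dag=dag,
)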

View file

@@ -37,6 +37,7 @@ PARAMETERS = {
}
PY_FILE = 'apache_beam.examples.wordcount'
JAR_FILE = 'unitest.jar'
JOB_CLASS = 'com.example.UnitTest'
PY_OPTIONS = ['-m']
DATAFLOW_OPTIONS_PY = {
'project': 'test',
@@ -62,7 +63,7 @@ def mock_init(self, gcp_conn_id, delegate_to=None):
pass
class DataFlowPythonHookTest(unittest.TestCase):
class DataFlowHookTest(unittest.TestCase):
def setUp(self):
with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'),
@@ -115,6 +116,30 @@ class DataFlowPythonHookTest(unittest.TestCase):
self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
sorted(EXPECTED_CMD))
@mock.patch(DATAFLOW_STRING.format('uuid.uuid1'))
@mock.patch(DATAFLOW_STRING.format('_DataflowJob'))
@mock.patch(DATAFLOW_STRING.format('_Dataflow'))
@mock.patch(DATAFLOW_STRING.format('DataFlowHook.get_conn'))
def test_start_java_dataflow_with_job_class(
self, mock_conn, mock_dataflow, mock_dataflowjob, mock_uuid):
mock_uuid.return_value = MOCK_UUID
mock_conn.return_value = None
dataflow_instance = mock_dataflow.return_value
dataflow_instance.wait_for_done.return_value = None
dataflowjob_instance = mock_dataflowjob.return_value
dataflowjob_instance.wait_for_done.return_value = None
self.dataflow_hook.start_java_dataflow(
task_id=TASK_ID, variables=DATAFLOW_OPTIONS_JAVA,
dataflow=JAR_FILE, job_class=JOB_CLASS)
EXPECTED_CMD = ['java', '-cp', JAR_FILE, JOB_CLASS,
'--runner=DataflowRunner', '--project=test',
'--stagingLocation=gs://test/staging',
'--labels={"foo":"bar"}',
'--jobName={}-{}'.format(TASK_ID, MOCK_UUID)]
self.assertListEqual(sorted(mock_dataflow.call_args[0][0]),
sorted(EXPECTED_CMD))
@mock.patch('airflow.contrib.hooks.gcp_dataflow_hook._Dataflow.log')
@mock.patch('subprocess.Popen')
@mock.patch('select.select')

View file

@@ -16,7 +16,7 @@
import unittest
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator, \
DataflowTemplateOperator
DataFlowJavaOperator, DataflowTemplateOperator
from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator
from airflow.version import version
@@ -36,8 +36,10 @@ PARAMETERS = {
'output': 'gs://test/output/my_output'
}
PY_FILE = 'gs://my-bucket/my-object.py'
JAR_FILE = 'example/test.jar'
JOB_CLASS = 'com.test.NotMain'
PY_OPTIONS = ['-m']
DEFAULT_OPTIONS_PYTHON = {
DEFAULT_OPTIONS_PYTHON = DEFAULT_OPTIONS_JAVA = {
'project': 'test',
'stagingLocation': 'gs://test/staging',
}
@@ -105,6 +107,44 @@ class DataFlowPythonOperatorTest(unittest.TestCase):
self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow'))
class DataFlowJavaOperatorTest(unittest.TestCase):
def setUp(self):
self.dataflow = DataFlowJavaOperator(
task_id=TASK_ID,
jar=JAR_FILE,
job_class=JOB_CLASS,
dataflow_default_options=DEFAULT_OPTIONS_JAVA,
options=ADDITIONAL_OPTIONS,
poll_sleep=POLL_SLEEP)
def test_init(self):
"""Test DataflowTemplateOperator instance is properly initialized."""
self.assertEqual(self.dataflow.task_id, TASK_ID)
self.assertEqual(self.dataflow.poll_sleep, POLL_SLEEP)
self.assertEqual(self.dataflow.dataflow_default_options,
DEFAULT_OPTIONS_JAVA)
self.assertEqual(self.dataflow.job_class, JOB_CLASS)
self.assertEqual(self.dataflow.jar, JAR_FILE)
self.assertEqual(self.dataflow.options,
EXPECTED_ADDITIONAL_OPTIONS)
@mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook')
@mock.patch(GCS_HOOK_STRING.format('GoogleCloudBucketHelper'))
def test_exec(self, gcs_hook, dataflow_mock):
"""Test DataFlowHook is created and the right args are passed to
start_java_workflow.
"""
start_java_hook = dataflow_mock.return_value.start_java_dataflow
gcs_download_hook = gcs_hook.return_value.google_cloud_to_local
self.dataflow.execute(None)
self.assertTrue(dataflow_mock.called)
gcs_download_hook.assert_called_once_with(JAR_FILE)
start_java_hook.assert_called_once_with(TASK_ID, mock.ANY,
mock.ANY, JOB_CLASS)
class DataFlowTemplateOperatorTest(unittest.TestCase):
def setUp(self):