Add DataflowStartFlexTemplateOperator (#8550)
This commit is contained in:
Parent: 45d608396d
Commit: 3c10ca6504
@@ -0,0 +1,61 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Example Airflow DAG for Google Cloud Dataflow service
"""
import os

from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator
from airflow.utils.dates import days_ago

GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")

DATAFLOW_FLEX_TEMPLATE_JOB_NAME = os.environ.get('DATAFLOW_FLEX_TEMPLATE_JOB_NAME', "dataflow-flex-template")

# For simplicity we use the same topic name as the subscription name.
PUBSUB_FLEX_TEMPLATE_TOPIC = os.environ.get('DATAFLOW_PUBSUB_FLEX_TEMPLATE_TOPIC', "dataflow-flex-template")
PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION = PUBSUB_FLEX_TEMPLATE_TOPIC
GCS_FLEX_TEMPLATE_TEMPLATE_PATH = os.environ.get(
    'DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH',
    "gs://test-airflow-dataflow-flex-template/samples/dataflow/templates/streaming-beam-sql.json",
)
BQ_FLEX_TEMPLATE_DATASET = os.environ.get('DATAFLOW_BQ_FLEX_TEMPLATE_DATASET', 'airflow_dataflow_samples')
BQ_FLEX_TEMPLATE_LOCATION = os.environ.get('DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION', 'us-west1')

with models.DAG(
    dag_id="example_gcp_dataflow_flex_template_java",
    start_date=days_ago(1),
    schedule_interval=None,  # Override to match your needs
) as dag_flex_template:
    start_flex_template = DataflowStartFlexTemplateOperator(
        task_id="start_flex_template_streaming_beam_sql",
        body={
            "launchParameter": {
                "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
                "jobName": DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
                "parameters": {
                    "inputSubscription": PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
                    "outputTable": f"{GCP_PROJECT_ID}:{BQ_FLEX_TEMPLATE_DATASET}.streaming_beam_sql",
                },
            }
        },
        do_xcom_push=True,
        location=BQ_FLEX_TEMPLATE_LOCATION,
    )
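For reference, the body above follows the request schema of the Dataflow projects.locations.flexTemplates.launch method that the operator ultimately calls. A sketch of the same body with the optional environment block added; the temp bucket and worker count are illustrative assumptions, not part of this commit:

# Sketch only: launch body with an optional runtime environment block.
alternative_body = {
    "launchParameter": {
        "containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
        "jobName": DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
        "parameters": {
            "inputSubscription": PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
            "outputTable": f"{GCP_PROJECT_ID}:{BQ_FLEX_TEMPLATE_DATASET}.streaming_beam_sql",
        },
        "environment": {
            "tempLocation": "gs://example-bucket/tmp",  # illustrative
            "maxWorkers": 2,  # illustrative
        },
    }
}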
@@ -96,16 +96,32 @@ _fallback_to_project_id_from_variables = _fallback_variable_parameter('project_i
class DataflowJobStatus:
    """
    Helper class with Dataflow job statuses.
    Reference: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState
    """

    JOB_STATE_DONE = "JOB_STATE_DONE"
    JOB_STATE_UNKNOWN = "JOB_STATE_UNKNOWN"
    JOB_STATE_STOPPED = "JOB_STATE_STOPPED"
    JOB_STATE_RUNNING = "JOB_STATE_RUNNING"
    JOB_STATE_FAILED = "JOB_STATE_FAILED"
    JOB_STATE_CANCELLED = "JOB_STATE_CANCELLED"
    JOB_STATE_UPDATED = "JOB_STATE_UPDATED"
    JOB_STATE_DRAINING = "JOB_STATE_DRAINING"
    JOB_STATE_DRAINED = "JOB_STATE_DRAINED"
    JOB_STATE_PENDING = "JOB_STATE_PENDING"
    JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING"
    JOB_STATE_QUEUED = "JOB_STATE_QUEUED"
    FAILED_END_STATES = {JOB_STATE_FAILED, JOB_STATE_CANCELLED}
-   SUCCEEDED_END_STATES = {JOB_STATE_DONE}
-   END_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
+   SUCCEEDED_END_STATES = {JOB_STATE_DONE, JOB_STATE_UPDATED, JOB_STATE_DRAINED}
+   TERMINAL_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
+   AWAITING_STATES = {
+       JOB_STATE_RUNNING,
+       JOB_STATE_PENDING,
+       JOB_STATE_QUEUED,
+       JOB_STATE_CANCELLING,
+       JOB_STATE_DRAINING,
+       JOB_STATE_STOPPED,
+   }


class DataflowJobType:

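The commit splits job states into three sets: FAILED_END_STATES and SUCCEEDED_END_STATES together form TERMINAL_STATES, while AWAITING_STATES covers jobs that are still progressing or shutting down. A minimal sketch of how such a partition can be queried; these helpers are illustrative and not part of this commit:

# Sketch only: illustrative helpers over the state sets defined above.
def is_terminal(state: str) -> bool:
    return state in DataflowJobStatus.TERMINAL_STATES


def is_successful(state: str) -> bool:
    return state in DataflowJobStatus.SUCCEEDED_END_STATES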
@@ -170,7 +186,7 @@ class _DataflowJobsController(LoggingMixin):
            return False

        for job in self._jobs:
-           if job['currentState'] not in DataflowJobStatus.END_STATES:
+           if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES:
                return True
        return False

@@ -261,10 +277,7 @@ class _DataflowJobsController(LoggingMixin):
            and DataflowJobType.JOB_TYPE_STREAMING == job['type']
        ):
            return True
-       elif job['currentState'] in {
-           DataflowJobStatus.JOB_STATE_RUNNING,
-           DataflowJobStatus.JOB_STATE_PENDING,
-       }:
+       elif job['currentState'] in DataflowJobStatus.AWAITING_STATES:
            return False
        self.log.debug("Current job: %s", str(job))
        raise Exception(
@@ -282,14 +295,14 @@ class _DataflowJobsController(LoggingMixin):
            time.sleep(self._poll_sleep)
            self._refresh_jobs()

-   def get_jobs(self) -> List[dict]:
+   def get_jobs(self, refresh=False) -> List[dict]:
        """
        Returns Dataflow jobs.

        :return: list of jobs
        :rtype: list
        """
-       if not self._jobs:
+       if not self._jobs or refresh:
            self._refresh_jobs()
        if not self._jobs:
            raise ValueError("Could not read _jobs")
@@ -300,23 +313,26 @@ class _DataflowJobsController(LoggingMixin):
        """
        Cancels current job
        """
-       jobs = self._get_current_jobs()
-       batch = self._dataflow.new_batch_http_request()
-       job_ids = [job['id'] for job in jobs]
-       self.log.info("Canceling jobs: %s", ", ".join(job_ids))
-       for job_id in job_ids:
-           batch.add(
-               self._dataflow.projects()
-               .locations()
-               .jobs()
-               .update(
-                   projectId=self._project_number,
-                   location=self._job_location,
-                   jobId=job_id,
-                   body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
-               )
-           )
-       batch.execute()
+       jobs = self.get_jobs()
+       job_ids = [job['id'] for job in jobs if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES]
+       if job_ids:
+           batch = self._dataflow.new_batch_http_request()
+           self.log.info("Canceling jobs: %s", ", ".join(job_ids))
+           for job_id in job_ids:
+               batch.add(
+                   self._dataflow.projects()
+                   .locations()
+                   .jobs()
+                   .update(
+                       projectId=self._project_number,
+                       location=self._job_location,
+                       jobId=job_id,
+                       body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
+                   )
+               )
+           batch.execute()
+       else:
+           self.log.info("No jobs to cancel")


class _DataflowRunner(LoggingMixin):

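The rewritten cancel() first filters out jobs that are already in a terminal state and only creates the batch request when something is left to cancel, so an empty batch is never executed. For reference, a standalone sketch of the same googleapiclient batching pattern, assuming a client built with googleapiclient.discovery; the project, location and job ids are illustrative:

# Sketch only: batching several job updates into one HTTP round trip.
from googleapiclient.discovery import build

dataflow = build("dataflow", "v1b3")
batch = dataflow.new_batch_http_request()
for job_id in ["job-1", "job-2"]:  # illustrative ids
    batch.add(
        dataflow.projects()
        .locations()
        .jobs()
        .update(
            projectId="example-project",
            location="us-central1",
            jobId=job_id,
            body={"requestedState": "JOB_STATE_CANCELLED"},
        )
    )
batch.execute()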
@@ -631,6 +647,52 @@ class DataflowHook(GoogleBaseHook):
        jobs_controller.wait_for_done()
        return response["job"]

+   @GoogleBaseHook.fallback_to_default_project_id
+   def start_flex_template(
+       self,
+       body: dict,
+       location: str,
+       project_id: str,
+       on_new_job_id_callback: Optional[Callable[[str], None]] = None,
+   ):
+       """
+       Starts flex templates with the Dataflow pipeline.
+
+       :param body: The request body. See:
+           https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
+       :param location: The location of the Dataflow job (for example europe-west1)
+       :type location: str
+       :param project_id: The ID of the GCP project that owns the job.
+           If set to ``None`` or missing, the default project_id from the GCP connection is used.
+       :type project_id: Optional[str]
+       :param on_new_job_id_callback: A callback that is called when a Job ID is detected.
+       :return: the Job
+       """
+       service = self.get_conn()
+       request = (
+           service.projects()  # pylint: disable=no-member
+           .locations()
+           .flexTemplates()
+           .launch(projectId=project_id, body=body, location=location)
+       )
+       response = request.execute(num_retries=self.num_retries)
+       job_id = response['job']['id']
+
+       if on_new_job_id_callback:
+           on_new_job_id_callback(job_id)
+
+       jobs_controller = _DataflowJobsController(
+           dataflow=self.get_conn(),
+           project_number=project_id,
+           job_id=job_id,
+           location=location,
+           poll_sleep=self.poll_sleep,
+           num_retries=self.num_retries,
+       )
+       jobs_controller.wait_for_done()
+
+       return jobs_controller.get_jobs(refresh=True)[0]
+
    @_fallback_to_location_from_variables
    @_fallback_to_project_id_from_variables
    @GoogleBaseHook.fallback_to_default_project_id
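A minimal sketch of calling the new hook method outside of an operator; the connection id and body values are illustrative, and the call blocks on the jobs controller before returning the job resource:

# Sketch only: direct use of DataflowHook.start_flex_template with illustrative values.
from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
job = hook.start_flex_template(
    body={
        "launchParameter": {
            "containerSpecGcsPath": "gs://example-bucket/templates/streaming-beam-sql.json",
            "jobName": "example-flex-job",
            "parameters": {"inputSubscription": "example-subscription"},
        }
    },
    location="us-central1",
    project_id="example-project",
)
print(job["id"], job["currentState"])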
@@ -659,6 +721,9 @@ class DataflowHook(GoogleBaseHook):
        :type dataflow: str
        :param py_options: Additional options.
        :type py_options: List[str]
+       :param project_id: The ID of the GCP project that owns the job.
+           If set to ``None`` or missing, the default project_id from the GCP connection is used.
+       :type project_id: Optional[str]
        :param py_interpreter: Python version of the beam pipeline.
            If None, this defaults to the python3.
            To track python versions supported by beam and related
@@ -451,6 +451,72 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
        self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)


+class DataflowStartFlexTemplateOperator(BaseOperator):
+   """
+   Starts flex templates with the Dataflow pipeline.
+
+   :param body: The request body. See:
+       https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
+   :param location: The location of the Dataflow job (for example europe-west1)
+   :type location: str
+   :param project_id: The ID of the GCP project that owns the job.
+       If set to ``None`` or missing, the default project_id from the GCP connection is used.
+   :type project_id: Optional[str]
+   :param gcp_conn_id: The connection ID to use connecting to Google Cloud
+       Platform.
+   :type gcp_conn_id: str
+   :param delegate_to: The account to impersonate, if any.
+       For this to work, the service account making the request must have
+       domain-wide delegation enabled.
+   :type delegate_to: str
+   """
+
+   template_fields = ["body", 'location', 'project_id', 'gcp_conn_id']
+
+   @apply_defaults
+   def __init__(
+       self,
+       body: Dict,
+       location: str,
+       project_id: Optional[str] = None,
+       gcp_conn_id: str = 'google_cloud_default',
+       delegate_to: Optional[str] = None,
+       *args,
+       **kwargs,
+   ) -> None:
+       super().__init__(*args, **kwargs)
+       self.body = body
+       self.location = location
+       self.project_id = project_id
+       self.gcp_conn_id = gcp_conn_id
+       self.delegate_to = delegate_to
+       self.job_id = None
+       self.hook: Optional[DataflowHook] = None
+
+   def execute(self, context):
+       self.hook = DataflowHook(
+           gcp_conn_id=self.gcp_conn_id,
+           delegate_to=self.delegate_to,
+       )
+
+       def set_current_job_id(job_id):
+           self.job_id = job_id
+
+       job = self.hook.start_flex_template(
+           body=self.body,
+           location=self.location,
+           project_id=self.project_id,
+           on_new_job_id_callback=set_current_job_id,
+       )
+
+       return job
+
+   def on_kill(self) -> None:
+       self.log.info("On kill.")
+       if self.job_id:
+           self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
+
+
class DataflowCreatePythonJobOperator(BaseOperator):
    """
    Launching Cloud Dataflow jobs written in python. Note that both
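Because execute() returns the job resource and the example DAG sets do_xcom_push=True, a downstream task can read the launched job from XCom. A short sketch; the task id matches the example DAG above and everything else is illustrative:

# Sketch only: reading the launched Dataflow job from XCom in a downstream callable.
def report_job(**context):
    job = context["ti"].xcom_pull(task_ids="start_flex_template_streaming_beam_sql")
    print("Launched Dataflow job:", job["id"])

# In a templated field the equivalent expression would be:
# "{{ task_instance.xcom_pull(task_ids='start_flex_template_streaming_beam_sql')['id'] }}"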
@@ -81,6 +81,15 @@ TEST_PROJECT = 'test-project'
TEST_JOB_ID = 'test-job-id'
TEST_LOCATION = 'custom-location'
DEFAULT_PY_INTERPRETER = 'python3'
+TEST_FLEX_PARAMETERS = {
+   "containerSpecGcsPath": "gs://test-bucket/test-file",
+   "jobName": 'test-job-name',
+   "parameters": {
+       "inputSubscription": 'test-subscription',
+       "outputTable": "test-project:test-dataset.streaming_beam_sql",
+   },
+}
+TEST_PROJECT_ID = 'test-project-id'


class TestFallbackToVariables(unittest.TestCase):

@@ -812,6 +821,40 @@ class TestDataflowTemplateHook(unittest.TestCase):
        )
        mock_uuid.assert_called_once_with()

+   @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
+   @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
+   def test_start_flex_template(self, mock_conn, mock_controller):
+       mock_locations = mock_conn.return_value.projects.return_value.locations
+       launch_method = mock_locations.return_value.flexTemplates.return_value.launch
+       launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
+       mock_controller.return_value.get_jobs.return_value = [{"id": TEST_JOB_ID}]
+
+       on_new_job_id_callback = mock.MagicMock()
+       result = self.dataflow_hook.start_flex_template(
+           body={"launchParameter": TEST_FLEX_PARAMETERS},
+           location=TEST_LOCATION,
+           project_id=TEST_PROJECT_ID,
+           on_new_job_id_callback=on_new_job_id_callback,
+       )
+       on_new_job_id_callback.assert_called_once_with(TEST_JOB_ID)
+       launch_method.assert_called_once_with(
+           projectId='test-project-id',
+           body={'launchParameter': TEST_FLEX_PARAMETERS},
+           location=TEST_LOCATION,
+       )
+       mock_controller.assert_called_once_with(
+           dataflow=mock_conn.return_value,
+           project_number=TEST_PROJECT_ID,
+           job_id=TEST_JOB_ID,
+           location=TEST_LOCATION,
+           poll_sleep=self.dataflow_hook.poll_sleep,
+           num_retries=self.dataflow_hook.num_retries,
+       )
+       mock_controller.return_value.wait_for_done.assert_called_once_with()
+       mock_controller.return_value.get_jobs.assert_called_once_with(refresh=True)
+
+       self.assertEqual(result, {"id": TEST_JOB_ID})
+
    @mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
    @mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
    def test_cancel_job(self, mock_get_conn, jobs_controller):
@@ -1114,54 +1157,37 @@ class TestDataflowJob(unittest.TestCase):
        self.assertEqual(False, result)

    def test_dataflow_job_cancel_job(self):
-       job = {
-           "id": TEST_JOB_ID,
-           "name": UNIQUE_JOB_NAME,
-           "currentState": DataflowJobStatus.JOB_STATE_RUNNING,
-       }
-       # fmt: off
-       get_method = (
-           self.mock_dataflow.projects.return_value.
-           locations.return_value.
-           jobs.return_value.
-           get
-       )
-       get_method.return_value.execute.return_value = job
+       mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
+       get_method = mock_jobs.return_value.get
+       get_method.return_value.execute.side_effect = [
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_RUNNING},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_PENDING},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_QUEUED},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_CANCELLING},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DRAINING},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_STOPPED},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED},
+       ]

-       (
-           self.mock_dataflow.projects.return_value.
-           locations.return_value.
-           jobs.return_value.
-           list_next.return_value
-       ) = None
-       # fmt: on
+       mock_jobs.return_value.list_next.return_value = None
        dataflow_job = _DataflowJobsController(
            dataflow=self.mock_dataflow,
            project_number=TEST_PROJECT,
            name=UNIQUE_JOB_NAME,
            location=TEST_LOCATION,
-           poll_sleep=10,
+           poll_sleep=0,
            job_id=TEST_JOB_ID,
            num_retries=20,
            multiple_jobs=False,
        )
        dataflow_job.cancel()

-       get_method.assert_called_once_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
-
-       get_method.return_value.execute.assert_called_once_with(num_retries=20)
+       get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
+       get_method.return_value.execute.assert_called_with(num_retries=20)

        self.mock_dataflow.new_batch_http_request.assert_called_once_with()
        mock_batch = self.mock_dataflow.new_batch_http_request.return_value
-       # fmt: off
-       mock_update = (
-           self.mock_dataflow.projects.return_value.
-           locations.return_value.
-           jobs.return_value.
-           update
-       )
-       # fmt: on
+       mock_update = mock_jobs.return_value.update
        mock_update.assert_called_once_with(
            body={'requestedState': 'JOB_STATE_CANCELLED'},
            jobId='test-job-id',
@@ -1169,7 +1195,36 @@
            projectId='test-project',
        )
        mock_batch.add.assert_called_once_with(mock_update.return_value)
        mock_batch.execute.assert_called_once()

+   def test_dataflow_job_cancel_job_no_running_jobs(self):
+       mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
+       get_method = mock_jobs.return_value.get
+       get_method.return_value.execute.side_effect = [
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_UPDATED},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DRAINED},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_FAILED},
+           {"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED},
+       ]
+
+       mock_jobs.return_value.list_next.return_value = None
+       dataflow_job = _DataflowJobsController(
+           dataflow=self.mock_dataflow,
+           project_number=TEST_PROJECT,
+           name=UNIQUE_JOB_NAME,
+           location=TEST_LOCATION,
+           poll_sleep=0,
+           job_id=TEST_JOB_ID,
+           num_retries=20,
+           multiple_jobs=False,
+       )
+       dataflow_job.cancel()
+
+       get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
+       get_method.return_value.execute.assert_called_with(num_retries=20)
+
+       self.mock_dataflow.new_batch_http_request.assert_not_called()
+       mock_jobs.return_value.update.assert_not_called()


APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f""""\

@@ -26,11 +26,13 @@ from airflow.providers.google.cloud.operators.dataflow import (
    DataflowCreateJavaJobOperator,
    DataflowCreatePythonJobOperator,
    DataflowTemplatedJobStartOperator,
+   DataflowStartFlexTemplateOperator,
)
from airflow.version import version

TASK_ID = 'test-dataflow-operator'
-JOB_NAME = 'test-dataflow-pipeline'
+JOB_ID = 'test-dataflow-pipeline-id'
+JOB_NAME = 'test-dataflow-pipeline-name'
TEMPLATE = 'gs://dataflow-templates/wordcount/template_file'
PARAMETERS = {
    'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt',
@@ -59,7 +61,16 @@ EXPECTED_ADDITIONAL_OPTIONS = {
}
POLL_SLEEP = 30
GCS_HOOK_STRING = 'airflow.providers.google.cloud.operators.dataflow.{}'
-TEST_LOCATION = "custom-location"
+TEST_FLEX_PARAMETERS = {
+   "containerSpecGcsPath": "gs://test-bucket/test-file",
+   "jobName": 'test-job-name',
+   "parameters": {
+       "inputSubscription": 'test-subscription',
+       "outputTable": "test-project:test-dataset.streaming_beam_sql",
+   },
+}
+TEST_LOCATION = 'custom-location'
+TEST_PROJECT_ID = 'test-project-id'


class TestDataflowPythonOperator(unittest.TestCase):

@@ -290,3 +301,37 @@ class TestDataflowTemplateOperator(unittest.TestCase):
            location=TEST_LOCATION,
            environment={'maxWorkers': 2},
        )
+
+
+class TestDataflowStartFlexTemplateOperator(unittest.TestCase):
+   @mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
+   def test_execute(self, mock_dataflow):
+       start_flex_template = DataflowStartFlexTemplateOperator(
+           task_id="start_flex_template_streaming_beam_sql",
+           body={"launchParameter": TEST_FLEX_PARAMETERS},
+           do_xcom_push=True,
+           project_id=TEST_PROJECT_ID,
+           location=TEST_LOCATION,
+       )
+       start_flex_template.execute(mock.MagicMock())
+       mock_dataflow.return_value.start_flex_template.assert_called_once_with(
+           body={"launchParameter": TEST_FLEX_PARAMETERS},
+           location=TEST_LOCATION,
+           project_id=TEST_PROJECT_ID,
+           on_new_job_id_callback=mock.ANY,
+       )
+
+   def test_on_kill(self):
+       start_flex_template = DataflowStartFlexTemplateOperator(
+           task_id="start_flex_template_streaming_beam_sql",
+           body={"launchParameter": TEST_FLEX_PARAMETERS},
+           do_xcom_push=True,
+           location=TEST_LOCATION,
+           project_id=TEST_PROJECT_ID,
+       )
+       start_flex_template.hook = mock.MagicMock()
+       start_flex_template.job_id = JOB_ID
+       start_flex_template.on_kill()
+       start_flex_template.hook.cancel_job.assert_called_once_with(
+           job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT_ID
+       )

@@ -15,9 +15,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import pytest
+import json
+import os
+import shlex
+import textwrap
+from tempfile import NamedTemporaryFile
+from urllib.parse import urlparse

-from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAFLOW_KEY
+import pytest
+import requests
+
+from airflow.providers.google.cloud.example_dags.example_dataflow_flex_template import (
+   BQ_FLEX_TEMPLATE_DATASET,
+   BQ_FLEX_TEMPLATE_LOCATION,
+   DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
+   GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
+   PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
+   PUBSUB_FLEX_TEMPLATE_TOPIC,
+)
+from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAFLOW_KEY, GCP_GCS_TRANSFER_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context

@@ -35,3 +51,193 @@ class CloudDataflowExampleDagsSystemTest(GoogleSystemTest):
    @provide_gcp_context(GCP_DATAFLOW_KEY)
    def test_run_example_gcp_dataflow_template(self):
        self.run_dag('example_gcp_dataflow_template', CLOUD_DAG_FOLDER)
+
+
+GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
+GCR_FLEX_TEMPLATE_IMAGE = f"gcr.io/{GCP_PROJECT_ID}/samples-dataflow-streaming-beam-sql:latest"
+
+# https://github.com/GoogleCloudPlatform/java-docs-samples/tree/954553c/dataflow/flex-templates/streaming_beam_sql
+GCS_TEMPLATE_PARTS = urlparse(GCS_FLEX_TEMPLATE_TEMPLATE_PATH)
+GCS_FLEX_TEMPLATE_BUCKET_NAME = GCS_TEMPLATE_PARTS.netloc
+
+
+EXAMPLE_FLEX_TEMPLATE_REPO = "GoogleCloudPlatform/java-docs-samples"
+EXAMPLE_FLEX_TEMPLATE_COMMIT = "deb0745be1d1ac1d133e1f0a7faa9413dbfbe5fe"
+EXAMPLE_FLEX_TEMPLATE_SUBDIR = "dataflow/flex-templates/streaming_beam_sql"
+
+
+@pytest.mark.backend("mysql", "postgres")
+@pytest.mark.credential_file(GCP_GCS_TRANSFER_KEY)
+class CloudDataflowExampleDagFlexTemplateJavaSystemTest(GoogleSystemTest):
+   @provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id())
+   def setUp(self) -> None:
+       # Create a Cloud Storage bucket
+       self.execute_cmd(["gsutil", "mb", f"gs://{GCS_FLEX_TEMPLATE_BUCKET_NAME}"])
+
+       # Build image with pipeline
+       with NamedTemporaryFile("w") as f:
+           cloud_build_config = {
+               'steps': [
+                   {'name': 'gcr.io/cloud-builders/git', 'args': ['clone', "$_EXAMPLE_REPO", "repo_dir"]},
+                   {
+                       'name': 'gcr.io/cloud-builders/git',
+                       'args': ['checkout', '$_EXAMPLE_COMMIT'],
+                       'dir': 'repo_dir',
+                   },
+                   {
+                       'name': 'maven',
+                       'args': ['mvn', 'clean', 'package'],
+                       'dir': 'repo_dir/$_EXAMPLE_SUBDIR',
+                   },
+                   {
+                       'name': 'gcr.io/cloud-builders/docker',
+                       'args': ['build', '-t', '$_TEMPLATE_IMAGE', '.'],
+                       'dir': 'repo_dir/$_EXAMPLE_SUBDIR',
+                   },
+               ],
+               'images': ['$_TEMPLATE_IMAGE'],
+           }
+           f.write(json.dumps(cloud_build_config))
+           f.flush()
+           self.execute_cmd(["cat", f.name])
+           substitutions = {
+               "_TEMPLATE_IMAGE": GCR_FLEX_TEMPLATE_IMAGE,
+               "_EXAMPLE_REPO": f"https://github.com/{EXAMPLE_FLEX_TEMPLATE_REPO}.git",
+               "_EXAMPLE_SUBDIR": EXAMPLE_FLEX_TEMPLATE_SUBDIR,
+               "_EXAMPLE_COMMIT": EXAMPLE_FLEX_TEMPLATE_COMMIT,
+           }
+           self.execute_cmd(
+               [
+                   "gcloud",
+                   "builds",
+                   "submit",
+                   "--substitutions="
+                   + ",".join([f"{k}={shlex.quote(v)}" for k, v in substitutions.items()]),
+                   f"--config={f.name}",
+                   "--no-source",
+               ]
+           )
+
+       # Build template
+       with NamedTemporaryFile() as f:  # type: ignore
+           manifest_url = (
+               f"https://raw.githubusercontent.com/"
+               f"{EXAMPLE_FLEX_TEMPLATE_REPO}/{EXAMPLE_FLEX_TEMPLATE_COMMIT}/"
+               f"{EXAMPLE_FLEX_TEMPLATE_SUBDIR}/metadata.json"
+           )
+           f.write(requests.get(manifest_url).content)  # type: ignore
+           f.flush()
+           self.execute_cmd(
+               [
+                   "gcloud",
+                   "beta",
+                   "dataflow",
+                   "flex-template",
+                   "build",
+                   GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
+                   "--image",
+                   GCR_FLEX_TEMPLATE_IMAGE,
+                   "--sdk-language",
+                   "JAVA",
+                   "--metadata-file",
+                   f.name,
+               ]
+           )
+
+       # Create a Pub/Sub topic and a subscription to that topic
+       self.execute_cmd(["gcloud", "pubsub", "topics", "create", PUBSUB_FLEX_TEMPLATE_TOPIC])
+       self.execute_cmd(
+           [
+               "gcloud",
+               "pubsub",
+               "subscriptions",
+               "create",
+               "--topic",
+               PUBSUB_FLEX_TEMPLATE_TOPIC,
+               PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
+           ]
+       )
+       # Create a publisher for "positive ratings" that publishes 1 message per minute
+       self.execute_cmd(
+           [
+               "gcloud",
+               "scheduler",
+               "jobs",
+               "create",
+               "pubsub",
+               "positive-ratings-publisher",
+               '--schedule=* * * * *',
+               f"--topic={PUBSUB_FLEX_TEMPLATE_TOPIC}",
+               '--message-body=\'{"url": "https://beam.apache.org/", "review": "positive"}\'',
+           ]
+       )
+       # Create and run another similar publisher for "negative ratings"
+       self.execute_cmd(
+           [
+               "gcloud",
+               "scheduler",
+               "jobs",
+               "create",
+               "pubsub",
+               "negative-ratings-publisher",
+               '--schedule=*/2 * * * *',
+               f"--topic={PUBSUB_FLEX_TEMPLATE_TOPIC}",
+               '--message-body=\'{"url": "https://beam.apache.org/", "review": "negative"}\'',
+           ]
+       )
+
+       # Create a BigQuery dataset
+       self.execute_cmd(["bq", "mk", "--dataset", f'{self._project_id()}:{BQ_FLEX_TEMPLATE_DATASET}'])
+
+   @provide_gcp_context(GCP_GCS_TRANSFER_KEY)
+   def test_run_example_dag_function(self):
+       self.run_dag("example_gcp_dataflow_flex_template_java", CLOUD_DAG_FOLDER)
+
+   @provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id())
+   def tearDown(self) -> None:
+       # Stop the Dataflow pipeline.
+       self.execute_cmd(
+           [
+               "bash",
+               "-c",
+               textwrap.dedent(
+                   f"""\
+                   gcloud dataflow jobs list \
+                       --region={BQ_FLEX_TEMPLATE_LOCATION} \
+                       --filter 'NAME:{DATAFLOW_FLEX_TEMPLATE_JOB_NAME} AND STATE=Running' \
+                       --format 'value(JOB_ID)' \
+                       | xargs -r gcloud dataflow jobs cancel --region={BQ_FLEX_TEMPLATE_LOCATION}
+                   """
+               ),
+           ]
+       )
+
+       # Delete the template spec file from Cloud Storage
+       self.execute_cmd(["gsutil", "rm", GCS_FLEX_TEMPLATE_TEMPLATE_PATH])
+
+       # Delete the Flex Template container image from Container Registry.
+       self.execute_cmd(
+           [
+               "gcloud",
+               "container",
+               "images",
+               "delete",
+               GCR_FLEX_TEMPLATE_IMAGE,
+               "--force-delete-tags",
+               "--quiet",
+           ]
+       )
+
+       # Delete the Cloud Scheduler jobs.
+       self.execute_cmd(["gcloud", "scheduler", "jobs", "delete", "negative-ratings-publisher", "--quiet"])
+       self.execute_cmd(["gcloud", "scheduler", "jobs", "delete", "positive-ratings-publisher", "--quiet"])
+
+       # Delete the Pub/Sub subscription and topic.
+       self.execute_cmd(["gcloud", "pubsub", "subscriptions", "delete", PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION])
+       self.execute_cmd(["gcloud", "pubsub", "topics", "delete", PUBSUB_FLEX_TEMPLATE_TOPIC])
+
+       # Delete the BigQuery dataset.
+       self.execute_cmd(["bq", "rm", "-r", "-f", "-d", f'{self._project_id()}:{BQ_FLEX_TEMPLATE_DATASET}'])
+
+       # Delete the Cloud Storage bucket
+       self.execute_cmd(["gsutil", "rm", "-r", f"gs://{GCS_FLEX_TEMPLATE_BUCKET_NAME}"])
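Once the DAG has run, the streaming pipeline should be appending rows to the BigQuery table named in the example DAG's outputTable. As a sketch, a manual check using the same bq CLI the system test already shells out to; the dataset and table names come from the example DAG, the rest is illustrative:

# Sketch only: verify that the pipeline is writing results.
import subprocess

subprocess.run(
    [
        "bq",
        "query",
        "--use_legacy_sql=false",
        "SELECT count(*) FROM `airflow_dataflow_samples.streaming_beam_sql`",
    ],
    check=True,
)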
@@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
+import shlex
import subprocess

from airflow.exceptions import AirflowException
@@ -26,11 +27,11 @@ class LoggingCommandExecutor(LoggingMixin):

    def execute_cmd(self, cmd, silent=False, cwd=None, env=None):
        if silent:
-           self.log.info("Executing in silent mode: '%s'", " ".join(cmd))
+           self.log.info("Executing in silent mode: '%s'", " ".join([shlex.quote(c) for c in cmd]))
            with open(os.devnull, 'w') as dev_null:
                return subprocess.call(args=cmd, stdout=dev_null, stderr=subprocess.STDOUT, env=env, cwd=cwd)
        else:
-           self.log.info("Executing: '%s'", " ".join(cmd))
+           self.log.info("Executing: '%s'", " ".join([shlex.quote(c) for c in cmd]))
            process = subprocess.Popen(
                args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                universal_newlines=True, cwd=cwd, env=env
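Wrapping each argument in shlex.quote before joining means the logged command can be copied back into a shell even when arguments contain spaces or quotes. A small illustration:

# Sketch only: effect of shlex.quote on logged commands.
import shlex

cmd = ["gcloud", "pubsub", "topics", "create", "topic with spaces"]
print(" ".join(shlex.quote(c) for c in cmd))
# prints: gcloud pubsub topics create 'topic with spaces'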
@@ -40,17 +41,17 @@ class LoggingCommandExecutor(LoggingMixin):
            self.log.info("Stdout: %s", output)
            self.log.info("Stderr: %s", err)
            if retcode:
-               self.log.error("Error when executing %s", " ".join(cmd))
+               self.log.error("Error when executing %s", " ".join([shlex.quote(c) for c in cmd]))
            return retcode

    def check_output(self, cmd):
-       self.log.info("Executing for output: '%s'", " ".join(cmd))
+       self.log.info("Executing for output: '%s'", " ".join([shlex.quote(c) for c in cmd]))
        process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE,
                                   stderr=subprocess.PIPE)
        output, err = process.communicate()
        retcode = process.poll()
        if retcode:
-           self.log.error("Error when executing '%s'", " ".join(cmd))
+           self.log.error("Error when executing '%s'", " ".join([shlex.quote(c) for c in cmd]))
            self.log.info("Stdout: %s", output)
            self.log.info("Stderr: %s", err)
            raise AirflowException("Retcode {} on {} with stdout: {}, stderr: {}".