Add DataflowStartFlexTemplateOperator (#8550)

This commit is contained in:
Kamil Breguła 2020-10-16 18:28:23 +02:00 committed by GitHub
Parent 45d608396d
Commit 3c10ca6504
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
7 changed files: 567 additions and 68 deletions

View File

@ -0,0 +1,61 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG for Google Cloud Dataflow service
"""
import os
from airflow import models
from airflow.providers.google.cloud.operators.dataflow import DataflowStartFlexTemplateOperator
from airflow.utils.dates import days_ago
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
DATAFLOW_FLEX_TEMPLATE_JOB_NAME = os.environ.get('DATAFLOW_FLEX_TEMPLATE_JOB_NAME', "dataflow-flex-template")
# For simplicity we use the same topic name as the subscription name.
PUBSUB_FLEX_TEMPLATE_TOPIC = os.environ.get('DATAFLOW_PUBSUB_FLEX_TEMPLATE_TOPIC', "dataflow-flex-template")
PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION = PUBSUB_FLEX_TEMPLATE_TOPIC
GCS_FLEX_TEMPLATE_TEMPLATE_PATH = os.environ.get(
'DATAFLOW_GCS_FLEX_TEMPLATE_TEMPLATE_PATH',
"gs://test-airflow-dataflow-flex-template/samples/dataflow/templates/streaming-beam-sql.json",
)
BQ_FLEX_TEMPLATE_DATASET = os.environ.get('DATAFLOW_BQ_FLEX_TEMPLATE_DATASET', 'airflow_dataflow_samples')
BQ_FLEX_TEMPLATE_LOCATION = os.environ.get('DATAFLOW_BQ_FLEX_TEMPLATE_LOCATION', 'us-west1')
with models.DAG(
dag_id="example_gcp_dataflow_flex_template_java",
start_date=days_ago(1),
schedule_interval=None, # Override to match your needs
) as dag_flex_template:
start_flex_template = DataflowStartFlexTemplateOperator(
task_id="start_flex_template_streaming_beam_sql",
body={
"launchParameter": {
"containerSpecGcsPath": GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
"jobName": DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
"parameters": {
"inputSubscription": PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
"outputTable": f"{GCP_PROJECT_ID}:{BQ_FLEX_TEMPLATE_DATASET}.streaming_beam_sql",
},
}
},
do_xcom_push=True,
location=BQ_FLEX_TEMPLATE_LOCATION,
)

View File

@ -96,16 +96,32 @@ _fallback_to_project_id_from_variables = _fallback_variable_parameter('project_i
class DataflowJobStatus:
"""
Helper class with Dataflow job statuses.
Reference: https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState
"""
JOB_STATE_DONE = "JOB_STATE_DONE"
JOB_STATE_UNKNOWN = "JOB_STATE_UNKNOWN"
JOB_STATE_STOPPED = "JOB_STATE_STOPPED"
JOB_STATE_RUNNING = "JOB_STATE_RUNNING"
JOB_STATE_FAILED = "JOB_STATE_FAILED"
JOB_STATE_CANCELLED = "JOB_STATE_CANCELLED"
JOB_STATE_UPDATED = "JOB_STATE_UPDATED"
JOB_STATE_DRAINING = "JOB_STATE_DRAINING"
JOB_STATE_DRAINED = "JOB_STATE_DRAINED"
JOB_STATE_PENDING = "JOB_STATE_PENDING"
JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING"
JOB_STATE_QUEUED = "JOB_STATE_QUEUED"
FAILED_END_STATES = {JOB_STATE_FAILED, JOB_STATE_CANCELLED}
SUCCEEDED_END_STATES = {JOB_STATE_DONE}
END_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
SUCCEEDED_END_STATES = {JOB_STATE_DONE, JOB_STATE_UPDATED, JOB_STATE_DRAINED}
TERMINAL_STATES = SUCCEEDED_END_STATES | FAILED_END_STATES
AWAITING_STATES = {
JOB_STATE_RUNNING,
JOB_STATE_PENDING,
JOB_STATE_QUEUED,
JOB_STATE_CANCELLING,
JOB_STATE_DRAINING,
JOB_STATE_STOPPED,
}
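For illustration only, a minimal sketch (not part of this change) of how the new state sets can be used to classify a job's currentState; the helper name is hypothetical:

def classify_state(current_state: str) -> str:
    """Map a Dataflow job state onto the buckets defined above (illustrative only)."""
    if current_state in DataflowJobStatus.SUCCEEDED_END_STATES:
        return "succeeded"
    if current_state in DataflowJobStatus.FAILED_END_STATES:
        return "failed"
    if current_state in DataflowJobStatus.AWAITING_STATES:
        return "awaiting"
    return "unknown"  # e.g. JOB_STATE_UNKNOWN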
class DataflowJobType:
@ -170,7 +186,7 @@ class _DataflowJobsController(LoggingMixin):
return False
for job in self._jobs:
if job['currentState'] not in DataflowJobStatus.END_STATES:
if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES:
return True
return False
@ -261,10 +277,7 @@ class _DataflowJobsController(LoggingMixin):
and DataflowJobType.JOB_TYPE_STREAMING == job['type']
):
return True
elif job['currentState'] in {
DataflowJobStatus.JOB_STATE_RUNNING,
DataflowJobStatus.JOB_STATE_PENDING,
}:
elif job['currentState'] in DataflowJobStatus.AWAITING_STATES:
return False
self.log.debug("Current job: %s", str(job))
raise Exception(
@ -282,14 +295,14 @@ class _DataflowJobsController(LoggingMixin):
time.sleep(self._poll_sleep)
self._refresh_jobs()
def get_jobs(self) -> List[dict]:
def get_jobs(self, refresh=False) -> List[dict]:
"""
Returns Dataflow jobs.
:param refresh: Forces the newest job list to be fetched from the API.
:type refresh: bool
:return: list of jobs
:rtype: list
"""
if not self._jobs:
if not self._jobs or refresh:
self._refresh_jobs()
if not self._jobs:
raise ValueError("Could not read _jobs")
@ -300,23 +313,26 @@ class _DataflowJobsController(LoggingMixin):
"""
Cancels current job
"""
jobs = self._get_current_jobs()
batch = self._dataflow.new_batch_http_request()
job_ids = [job['id'] for job in jobs]
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
batch.add(
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
jobs = self.get_jobs()
job_ids = [job['id'] for job in jobs if job['currentState'] not in DataflowJobStatus.TERMINAL_STATES]
if job_ids:
batch = self._dataflow.new_batch_http_request()
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
batch.add(
self._dataflow.projects()
.locations()
.jobs()
.update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED},
)
)
)
batch.execute()
batch.execute()
else:
self.log.info("No jobs to cancel")
class _DataflowRunner(LoggingMixin):
@ -631,6 +647,52 @@ class DataflowHook(GoogleBaseHook):
jobs_controller.wait_for_done()
return response["job"]
@GoogleBaseHook.fallback_to_default_project_id
def start_flex_template(
self,
body: dict,
location: str,
project_id: str,
on_new_job_id_callback: Optional[Callable[[str], None]] = None,
):
"""
Starts a Dataflow job from a Flex Template.
:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
:param location: The location of the Dataflow job (for example europe-west1)
:type location: str
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:type project_id: Optional[str]
:param on_new_job_id_callback: A callback that is called when a Job ID is detected.
:return: the Job
"""
service = self.get_conn()
request = (
service.projects() # pylint: disable=no-member
.locations()
.flexTemplates()
.launch(projectId=project_id, body=body, location=location)
)
response = request.execute(num_retries=self.num_retries)
job_id = response['job']['id']
if on_new_job_id_callback:
on_new_job_id_callback(job_id)
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
job_id=job_id,
location=location,
poll_sleep=self.poll_sleep,
num_retries=self.num_retries,
)
jobs_controller.wait_for_done()
return jobs_controller.get_jobs(refresh=True)[0]
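A minimal usage sketch of the new hook method, assuming a configured google_cloud_default connection; the bucket path, job name, subscription, table, and project below are placeholders, not values from this change:

from airflow.providers.google.cloud.hooks.dataflow import DataflowHook

hook = DataflowHook(gcp_conn_id="google_cloud_default")
job = hook.start_flex_template(
    body={
        "launchParameter": {
            "containerSpecGcsPath": "gs://my-bucket/templates/streaming-beam-sql.json",  # placeholder
            "jobName": "my-flex-template-job",  # placeholder
            "parameters": {
                "inputSubscription": "my-subscription",  # placeholder
                "outputTable": "my-project:my_dataset.my_table",  # placeholder
            },
        }
    },
    location="us-central1",
    project_id="my-project",  # placeholder
)
# The method waits via _DataflowJobsController.wait_for_done and returns the refreshed job resource.
print(job["id"], job.get("currentState"))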
@_fallback_to_location_from_variables
@_fallback_to_project_id_from_variables
@GoogleBaseHook.fallback_to_default_project_id
@ -659,6 +721,9 @@ class DataflowHook(GoogleBaseHook):
:type dataflow: str
:param py_options: Additional options.
:type py_options: List[str]
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:type project_id: Optional[str]
:param py_interpreter: Python version of the beam pipeline.
If None, this defaults to the python3.
To track python versions supported by beam and related

View File

@ -451,6 +451,72 @@ class DataflowTemplatedJobStartOperator(BaseOperator):
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
class DataflowStartFlexTemplateOperator(BaseOperator):
"""
Starts a Dataflow job from a Flex Template.
:param body: The request body. See:
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.locations.flexTemplates/launch#request-body
:param location: The location of the Dataflow job (for example europe-west1)
:type location: str
:param project_id: The ID of the GCP project that owns the job.
If set to ``None`` or missing, the default project_id from the GCP connection is used.
:type project_id: Optional[str]
:param gcp_conn_id: The connection ID to use connecting to Google Cloud
Platform.
:type gcp_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
"""
template_fields = ["body", 'location', 'project_id', 'gcp_conn_id']
@apply_defaults
def __init__(
self,
body: Dict,
location: str,
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.body = body
self.location = location
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.job_id = None
self.hook: Optional[DataflowHook] = None
def execute(self, context):
self.hook = DataflowHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
)
def set_current_job_id(job_id):
self.job_id = job_id
job = self.hook.start_flex_template(
body=self.body,
location=self.location,
project_id=self.project_id,
on_new_job_id_callback=set_current_job_id,
)
return job
def on_kill(self) -> None:
self.log.info("On kill.")
if self.job_id:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id)
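Because execute() returns the job resource, it is pushed to XCom when do_xcom_push=True (as in the example DAG above). A hedged sketch of a downstream task reading it; the task ids and Airflow 2.0-style context passing are assumptions:

from airflow.operators.python import PythonOperator

def _report_flex_template_job(**context):
    # Pull the job dict returned by DataflowStartFlexTemplateOperator.execute()
    job = context["ti"].xcom_pull(task_ids="start_flex_template_streaming_beam_sql")
    print(job["id"], job.get("currentState"))

report_job = PythonOperator(
    task_id="report_flex_template_job",  # hypothetical task id
    python_callable=_report_flex_template_job,
)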
class DataflowCreatePythonJobOperator(BaseOperator):
"""
Launching Cloud Dataflow jobs written in python. Note that both

View File

@ -81,6 +81,15 @@ TEST_PROJECT = 'test-project'
TEST_JOB_ID = 'test-job-id'
TEST_LOCATION = 'custom-location'
DEFAULT_PY_INTERPRETER = 'python3'
TEST_FLEX_PARAMETERS = {
"containerSpecGcsPath": "gs://test-bucket/test-file",
"jobName": 'test-job-name',
"parameters": {
"inputSubscription": 'test-subsription',
"outputTable": "test-project:test-dataset.streaming_beam_sql",
},
}
TEST_PROJECT_ID = 'test-project-id'
class TestFallbackToVariables(unittest.TestCase):
@ -812,6 +821,40 @@ class TestDataflowTemplateHook(unittest.TestCase):
)
mock_uuid.assert_called_once_with()
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_start_flex_template(self, mock_conn, mock_controller):
mock_locations = mock_conn.return_value.projects.return_value.locations
launch_method = mock_locations.return_value.flexTemplates.return_value.launch
launch_method.return_value.execute.return_value = {"job": {"id": TEST_JOB_ID}}
mock_controller.return_value.get_jobs.return_value = [{"id": TEST_JOB_ID}]
on_new_job_id_callback = mock.MagicMock()
result = self.dataflow_hook.start_flex_template(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
project_id=TEST_PROJECT_ID,
on_new_job_id_callback=on_new_job_id_callback,
)
on_new_job_id_callback.assert_called_once_with(TEST_JOB_ID)
launch_method.assert_called_once_with(
projectId='test-project-id',
body={'launchParameter': TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
)
mock_controller.assert_called_once_with(
dataflow=mock_conn.return_value,
project_number=TEST_PROJECT_ID,
job_id=TEST_JOB_ID,
location=TEST_LOCATION,
poll_sleep=self.dataflow_hook.poll_sleep,
num_retries=self.dataflow_hook.num_retries,
)
mock_controller.return_value.wait_for_done.assert_called_once_with()
mock_controller.return_value.get_jobs.assert_called_once_with(refresh=True)
self.assertEqual(result, {"id": TEST_JOB_ID})
@mock.patch(DATAFLOW_STRING.format('_DataflowJobsController'))
@mock.patch(DATAFLOW_STRING.format('DataflowHook.get_conn'))
def test_cancel_job(self, mock_get_conn, jobs_controller):
@ -1114,54 +1157,37 @@ class TestDataflowJob(unittest.TestCase):
self.assertEqual(False, result)
def test_dataflow_job_cancel_job(self):
job = {
"id": TEST_JOB_ID,
"name": UNIQUE_JOB_NAME,
"currentState": DataflowJobStatus.JOB_STATE_RUNNING,
}
# fmt: off
get_method = (
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
get
)
get_method.return_value.execute.return_value = job
mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
get_method = mock_jobs.return_value.get
get_method.return_value.execute.side_effect = [
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_RUNNING},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_PENDING},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_QUEUED},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_CANCELLING},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DRAINING},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_STOPPED},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED},
]
(
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
list_next.return_value
) = None
# fmt: on
mock_jobs.return_value.list_next.return_value = None
dataflow_job = _DataflowJobsController(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name=UNIQUE_JOB_NAME,
location=TEST_LOCATION,
poll_sleep=10,
poll_sleep=0,
job_id=TEST_JOB_ID,
num_retries=20,
multiple_jobs=False,
)
dataflow_job.cancel()
get_method.assert_called_once_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
get_method.return_value.execute.assert_called_once_with(num_retries=20)
get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
get_method.return_value.execute.assert_called_with(num_retries=20)
self.mock_dataflow.new_batch_http_request.assert_called_once_with()
mock_batch = self.mock_dataflow.new_batch_http_request.return_value
# fmt: off
mock_update = (
self.mock_dataflow.projects.return_value.
locations.return_value.
jobs.return_value.
update
)
# fmt: on
mock_update = mock_jobs.return_value.update
mock_update.assert_called_once_with(
body={'requestedState': 'JOB_STATE_CANCELLED'},
jobId='test-job-id',
@ -1169,7 +1195,36 @@ class TestDataflowJob(unittest.TestCase):
projectId='test-project',
)
mock_batch.add.assert_called_once_with(mock_update.return_value)
mock_batch.execute.assert_called_once()
def test_dataflow_job_cancel_job_no_running_jobs(self):
mock_jobs = self.mock_dataflow.projects.return_value.locations.return_value.jobs
get_method = mock_jobs.return_value.get
get_method.return_value.execute.side_effect = [
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DONE},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_UPDATED},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_DRAINED},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_FAILED},
{"id": TEST_JOB_ID, "name": JOB_NAME, "currentState": DataflowJobStatus.JOB_STATE_CANCELLED},
]
mock_jobs.return_value.list_next.return_value = None
dataflow_job = _DataflowJobsController(
dataflow=self.mock_dataflow,
project_number=TEST_PROJECT,
name=UNIQUE_JOB_NAME,
location=TEST_LOCATION,
poll_sleep=0,
job_id=TEST_JOB_ID,
num_retries=20,
multiple_jobs=False,
)
dataflow_job.cancel()
get_method.assert_called_with(jobId=TEST_JOB_ID, location=TEST_LOCATION, projectId=TEST_PROJECT)
get_method.return_value.execute.assert_called_with(num_retries=20)
self.mock_dataflow.new_batch_http_request.assert_not_called()
mock_jobs.return_value.update.assert_not_called()
APACHE_BEAM_V_2_14_0_JAVA_SDK_LOG = f""""\

View File

@ -26,11 +26,13 @@ from airflow.providers.google.cloud.operators.dataflow import (
DataflowCreateJavaJobOperator,
DataflowCreatePythonJobOperator,
DataflowTemplatedJobStartOperator,
DataflowStartFlexTemplateOperator,
)
from airflow.version import version
TASK_ID = 'test-dataflow-operator'
JOB_NAME = 'test-dataflow-pipeline'
JOB_ID = 'test-dataflow-pipeline-id'
JOB_NAME = 'test-dataflow-pipeline-name'
TEMPLATE = 'gs://dataflow-templates/wordcount/template_file'
PARAMETERS = {
'inputFile': 'gs://dataflow-samples/shakespeare/kinglear.txt',
@ -59,7 +61,16 @@ EXPECTED_ADDITIONAL_OPTIONS = {
}
POLL_SLEEP = 30
GCS_HOOK_STRING = 'airflow.providers.google.cloud.operators.dataflow.{}'
TEST_LOCATION = "custom-location"
TEST_FLEX_PARAMETERS = {
"containerSpecGcsPath": "gs://test-bucket/test-file",
"jobName": 'test-job-name',
"parameters": {
"inputSubscription": 'test-subsription',
"outputTable": "test-project:test-dataset.streaming_beam_sql",
},
}
TEST_LOCATION = 'custom-location'
TEST_PROJECT_ID = 'test-project-id'
class TestDataflowPythonOperator(unittest.TestCase):
@ -290,3 +301,37 @@ class TestDataflowTemplateOperator(unittest.TestCase):
location=TEST_LOCATION,
environment={'maxWorkers': 2},
)
class TestDataflowStartFlexTemplateOperator(unittest.TestCase):
@mock.patch('airflow.providers.google.cloud.operators.dataflow.DataflowHook')
def test_execute(self, mock_dataflow):
start_flex_template = DataflowStartFlexTemplateOperator(
task_id="start_flex_template_streaming_beam_sql",
body={"launchParameter": TEST_FLEX_PARAMETERS},
do_xcom_push=True,
project_id=TEST_PROJECT_ID,
location=TEST_LOCATION,
)
start_flex_template.execute(mock.MagicMock())
mock_dataflow.return_value.start_flex_template.assert_called_once_with(
body={"launchParameter": TEST_FLEX_PARAMETERS},
location=TEST_LOCATION,
project_id=TEST_PROJECT_ID,
on_new_job_id_callback=mock.ANY,
)
def test_on_kill(self):
start_flex_template = DataflowStartFlexTemplateOperator(
task_id="start_flex_template_streaming_beam_sql",
body={"launchParameter": TEST_FLEX_PARAMETERS},
do_xcom_push=True,
location=TEST_LOCATION,
project_id=TEST_PROJECT_ID,
)
start_flex_template.hook = mock.MagicMock()
start_flex_template.job_id = JOB_ID
start_flex_template.on_kill()
start_flex_template.hook.cancel_job.assert_called_once_with(
job_id='test-dataflow-pipeline-id', project_id=TEST_PROJECT_ID
)

View File

@ -15,9 +15,25 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import json
import os
import shlex
import textwrap
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAFLOW_KEY
import pytest
import requests
from airflow.providers.google.cloud.example_dags.example_dataflow_flex_template import (
BQ_FLEX_TEMPLATE_DATASET,
BQ_FLEX_TEMPLATE_LOCATION,
DATAFLOW_FLEX_TEMPLATE_JOB_NAME,
GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
PUBSUB_FLEX_TEMPLATE_TOPIC,
)
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_DATAFLOW_KEY, GCP_GCS_TRANSFER_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
@ -35,3 +51,193 @@ class CloudDataflowExampleDagsSystemTest(GoogleSystemTest):
@provide_gcp_context(GCP_DATAFLOW_KEY)
def test_run_example_gcp_dataflow_template(self):
self.run_dag('example_gcp_dataflow_template', CLOUD_DAG_FOLDER)
GCP_PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "example-project")
GCR_FLEX_TEMPLATE_IMAGE = f"gcr.io/{GCP_PROJECT_ID}/samples-dataflow-streaming-beam-sql:latest"
# https://github.com/GoogleCloudPlatform/java-docs-samples/tree/954553c/dataflow/flex-templates/streaming_beam_sql
GCS_TEMPLATE_PARTS = urlparse(GCS_FLEX_TEMPLATE_TEMPLATE_PATH)
GCS_FLEX_TEMPLATE_BUCKET_NAME = GCS_TEMPLATE_PARTS.netloc
EXAMPLE_FLEX_TEMPLATE_REPO = "GoogleCloudPlatform/java-docs-samples"
EXAMPLE_FLEX_TEMPLATE_COMMIT = "deb0745be1d1ac1d133e1f0a7faa9413dbfbe5fe"
EXAMPLE_FLEX_TEMPLATE_SUBDIR = "dataflow/flex-templates/streaming_beam_sql"
@pytest.mark.backend("mysql", "postgres")
@pytest.mark.credential_file(GCP_GCS_TRANSFER_KEY)
class CloudDataflowExampleDagFlexTemplateJavaSystemTest(GoogleSystemTest):
@provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id())
def setUp(self) -> None:
# Create a Cloud Storage bucket
self.execute_cmd(["gsutil", "mb", f"gs://{GCS_FLEX_TEMPLATE_BUCKET_NAME}"])
# Build image with pipeline
with NamedTemporaryFile("w") as f:
cloud_build_config = {
'steps': [
{'name': 'gcr.io/cloud-builders/git', 'args': ['clone', "$_EXAMPLE_REPO", "repo_dir"]},
{
'name': 'gcr.io/cloud-builders/git',
'args': ['checkout', '$_EXAMPLE_COMMIT'],
'dir': 'repo_dir',
},
{
'name': 'maven',
'args': ['mvn', 'clean', 'package'],
'dir': 'repo_dir/$_EXAMPLE_SUBDIR',
},
{
'name': 'gcr.io/cloud-builders/docker',
'args': ['build', '-t', '$_TEMPLATE_IMAGE', '.'],
'dir': 'repo_dir/$_EXAMPLE_SUBDIR',
},
],
'images': ['$_TEMPLATE_IMAGE'],
}
f.write(json.dumps(cloud_build_config))
f.flush()
self.execute_cmd(["cat", f.name])
substitutions = {
"_TEMPLATE_IMAGE": GCR_FLEX_TEMPLATE_IMAGE,
"_EXAMPLE_REPO": f"https://github.com/{EXAMPLE_FLEX_TEMPLATE_REPO}.git",
"_EXAMPLE_SUBDIR": EXAMPLE_FLEX_TEMPLATE_SUBDIR,
"_EXAMPLE_COMMIT": EXAMPLE_FLEX_TEMPLATE_COMMIT,
}
self.execute_cmd(
[
"gcloud",
"builds",
"submit",
"--substitutions="
+ ",".join([f"{k}={shlex.quote(v)}" for k, v in substitutions.items()]),
f"--config={f.name}",
"--no-source",
]
)
# Build template
with NamedTemporaryFile() as f: # type: ignore
manifest_url = (
f"https://raw.githubusercontent.com/"
f"{EXAMPLE_FLEX_TEMPLATE_REPO}/{EXAMPLE_FLEX_TEMPLATE_COMMIT}/"
f"{EXAMPLE_FLEX_TEMPLATE_SUBDIR}/metadata.json"
)
f.write(requests.get(manifest_url).content) # type: ignore
f.flush()
self.execute_cmd(
[
"gcloud",
"beta",
"dataflow",
"flex-template",
"build",
GCS_FLEX_TEMPLATE_TEMPLATE_PATH,
"--image",
GCR_FLEX_TEMPLATE_IMAGE,
"--sdk-language",
"JAVA",
"--metadata-file",
f.name,
]
)
# Create a Pub/Sub topic and a subscription to that topic
self.execute_cmd(["gcloud", "pubsub", "topics", "create", PUBSUB_FLEX_TEMPLATE_TOPIC])
self.execute_cmd(
[
"gcloud",
"pubsub",
"subscriptions",
"create",
"--topic",
PUBSUB_FLEX_TEMPLATE_TOPIC,
PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION,
]
)
# Create a publisher for "positive ratings" that publishes 1 message per minute
self.execute_cmd(
[
"gcloud",
"scheduler",
"jobs",
"create",
"pubsub",
"positive-ratings-publisher",
'--schedule=* * * * *',
f"--topic={PUBSUB_FLEX_TEMPLATE_TOPIC}",
'--message-body=\'{"url": "https://beam.apache.org/", "review": "positive"}\'',
]
)
# Create a similar publisher for "negative ratings" that publishes 1 message every 2 minutes
self.execute_cmd(
[
"gcloud",
"scheduler",
"jobs",
"create",
"pubsub",
"negative-ratings-publisher",
'--schedule=*/2 * * * *',
f"--topic={PUBSUB_FLEX_TEMPLATE_TOPIC}",
'--message-body=\'{"url": "https://beam.apache.org/", "review": "negative"}\'',
]
)
# Create a BigQuery dataset
self.execute_cmd(["bq", "mk", "--dataset", f'{self._project_id()}:{BQ_FLEX_TEMPLATE_DATASET}'])
@provide_gcp_context(GCP_GCS_TRANSFER_KEY)
def test_run_example_dag_function(self):
self.run_dag("example_gcp_dataflow_flex_template_java", CLOUD_DAG_FOLDER)
@provide_gcp_context(GCP_GCS_TRANSFER_KEY, project_id=GoogleSystemTest._project_id())
def tearDown(self) -> None:
# Stop the Dataflow pipeline.
self.execute_cmd(
[
"bash",
"-c",
textwrap.dedent(
f"""\
gcloud dataflow jobs list \
--region={BQ_FLEX_TEMPLATE_LOCATION} \
--filter 'NAME:{DATAFLOW_FLEX_TEMPLATE_JOB_NAME} AND STATE=Running' \
--format 'value(JOB_ID)' \
| xargs -r gcloud dataflow jobs cancel --region={BQ_FLEX_TEMPLATE_LOCATION}
"""
),
]
)
# Delete the template spec file from Cloud Storage
self.execute_cmd(["gsutil", "rm", GCS_FLEX_TEMPLATE_TEMPLATE_PATH])
# Delete the Flex Template container image from Container Registry.
self.execute_cmd(
[
"gcloud",
"container",
"images",
"delete",
GCR_FLEX_TEMPLATE_IMAGE,
"--force-delete-tags",
"--quiet",
]
)
# Delete the Cloud Scheduler jobs.
self.execute_cmd(["gcloud", "scheduler", "jobs", "delete", "negative-ratings-publisher", "--quiet"])
self.execute_cmd(["gcloud", "scheduler", "jobs", "delete", "positive-ratings-publisher", "--quiet"])
# Delete the Pub/Sub subscription and topic.
self.execute_cmd(["gcloud", "pubsub", "subscriptions", "delete", PUBSUB_FLEX_TEMPLATE_SUBSCRIPTION])
self.execute_cmd(["gcloud", "pubsub", "topics", "delete", PUBSUB_FLEX_TEMPLATE_TOPIC])
# Delete the BigQuery dataset.
self.execute_cmd(["bq", "rm", "-r", "-f", "-d", f'{self._project_id()}:{BQ_FLEX_TEMPLATE_DATASET}'])
# Delete the Cloud Storage bucket
self.execute_cmd(["gsutil", "rm", "-r", f"gs://{GCS_FLEX_TEMPLATE_BUCKET_NAME}"])

View File

@ -16,6 +16,7 @@
# specific language governing permissions and limitations
# under the License.
import os
import shlex
import subprocess
from airflow.exceptions import AirflowException
@ -26,11 +27,11 @@ class LoggingCommandExecutor(LoggingMixin):
def execute_cmd(self, cmd, silent=False, cwd=None, env=None):
if silent:
self.log.info("Executing in silent mode: '%s'", " ".join(cmd))
self.log.info("Executing in silent mode: '%s'", " ".join([shlex.quote(c) for c in cmd]))
with open(os.devnull, 'w') as dev_null:
return subprocess.call(args=cmd, stdout=dev_null, stderr=subprocess.STDOUT, env=env, cwd=cwd)
else:
self.log.info("Executing: '%s'", " ".join(cmd))
self.log.info("Executing: '%s'", " ".join([shlex.quote(c) for c in cmd]))
process = subprocess.Popen(
args=cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
universal_newlines=True, cwd=cwd, env=env
@ -40,17 +41,17 @@ class LoggingCommandExecutor(LoggingMixin):
self.log.info("Stdout: %s", output)
self.log.info("Stderr: %s", err)
if retcode:
self.log.error("Error when executing %s", " ".join(cmd))
self.log.error("Error when executing %s", " ".join([shlex.quote(c) for c in cmd]))
return retcode
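For context, a small illustration (the command values are hypothetical) of what shlex.quote changes in the logged command line: arguments containing spaces or quotes become copy-pasteable shell tokens.

import shlex

cmd = [
    "gcloud", "scheduler", "jobs", "create", "pubsub", "positive-ratings-publisher",
    '--message-body={"url": "https://beam.apache.org/", "review": "positive"}',
]
print(" ".join(cmd))                          # ambiguous: spaces and quotes inside one argument
print(" ".join(shlex.quote(c) for c in cmd))  # each argument is safely quoted for the shell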
def check_output(self, cmd):
self.log.info("Executing for output: '%s'", " ".join(cmd))
self.log.info("Executing for output: '%s'", " ".join([shlex.quote(c) for c in cmd]))
process = subprocess.Popen(args=cmd, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
output, err = process.communicate()
retcode = process.poll()
if retcode:
self.log.error("Error when executing '%s'", " ".join(cmd))
self.log.error("Error when executing '%s'", " ".join([shlex.quote(c) for c in cmd]))
self.log.info("Stdout: %s", output)
self.log.info("Stderr: %s", err)
raise AirflowException("Retcode {} on {} with stdout: {}, stderr: {}".