[AIRFLOW-3078] Basic operators for Google Compute Engine (#4022)

Add GceInstanceStartOperator, GceInstanceStopOperator and GceSetMachineTypeOperator.

Each operator includes:
- core logic
- input params validation
- unit tests
- presence in the example DAG
- docstrings
- How-to and Integration documentation

Additionally, GceHook now checks for errors even when the response is 200 OK:

Some types of errors are only visible in the response's "error" field
and the overall HTTP response is 200 OK.

That is why apart from checking if status is "done" we also check
if "error" is empty, and if not an exception is raised with error
message extracted from the "error" field of the response.

In this commit we also separated out the Body Field Validator into a
separate module in tools. This way it can be reused between
various GCP operators; it has proven to be usable in at least
two of them now.

Co-authored-by: sprzedwojski <szymon.przedwojski@polidea.com>
Co-authored-by: potiuk <jarek.potiuk@polidea.com>
This commit is contained in:
Szymon Przedwojski 2018-10-10 11:49:57 +02:00 коммит произвёл Kaxil Naik
Родитель 76ad6f0938
Коммит cdbdcae7c0
9 изменённых файлов: 1370 добавлений и 267 удалений

Просмотреть файл

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Example Airflow DAG that starts, stops and sets the machine type of a Google Compute
Engine instance.
This DAG relies on the following Airflow variables
https://airflow.apache.org/concepts.html#variables
* PROJECT_ID - Google Cloud Platform project where the Compute Engine instance exists.
* LOCATION - Google Cloud Platform zone where the instance exists.
* INSTANCE - Name of the Compute Engine instance.
* SHORT_MACHINE_TYPE_NAME - Machine type resource name to set, e.g. 'n1-standard-1'.
See https://cloud.google.com/compute/docs/machine-types
"""
import datetime
import airflow
from airflow import models
from airflow.contrib.operators.gcp_compute_operator import GceInstanceStartOperator, \
GceInstanceStopOperator, GceSetMachineTypeOperator
# [START howto_operator_gce_args]
# Every setting is read from an Airflow Variable and falls back to an empty
# string when that Variable has not been defined.
PROJECT_ID = models.Variable.get('PROJECT_ID', '')
LOCATION = models.Variable.get('LOCATION', '')
INSTANCE = models.Variable.get('INSTANCE', '')
SHORT_MACHINE_TYPE_NAME = models.Variable.get('SHORT_MACHINE_TYPE_NAME', '')

# Request body handed to GceSetMachineTypeOperator: the machine type is
# addressed by its zone-relative resource path.
SET_MACHINE_TYPE_BODY = dict(
    machineType='zones/%s/machineTypes/%s' % (LOCATION, SHORT_MACHINE_TYPE_NAME),
)

# Arguments shared by every task of the DAG.
default_args = dict(start_date=airflow.utils.dates.days_ago(1))
# [END howto_operator_gce_args]
with models.DAG(
    dag_id='example_gcp_compute',
    default_args=default_args,
    schedule_interval=datetime.timedelta(days=1),
) as dag:
    # [START howto_operator_gce_start]
    gce_instance_start = GceInstanceStartOperator(
        task_id='gcp_compute_start_task',
        project_id=PROJECT_ID,
        zone=LOCATION,
        resource_id=INSTANCE,
    )
    # [END howto_operator_gce_start]

    # Duplicate start for idempotence testing
    gce_instance_start2 = GceInstanceStartOperator(
        task_id='gcp_compute_start_task2',
        project_id=PROJECT_ID,
        zone=LOCATION,
        resource_id=INSTANCE,
    )

    # [START howto_operator_gce_stop]
    gce_instance_stop = GceInstanceStopOperator(
        task_id='gcp_compute_stop_task',
        project_id=PROJECT_ID,
        zone=LOCATION,
        resource_id=INSTANCE,
    )
    # [END howto_operator_gce_stop]

    # Duplicate stop for idempotence testing
    gce_instance_stop2 = GceInstanceStopOperator(
        task_id='gcp_compute_stop_task2',
        project_id=PROJECT_ID,
        zone=LOCATION,
        resource_id=INSTANCE,
    )

    # [START howto_operator_gce_set_machine_type]
    gce_set_machine_type = GceSetMachineTypeOperator(
        task_id='gcp_compute_set_machine_type',
        project_id=PROJECT_ID,
        zone=LOCATION,
        resource_id=INSTANCE,
        body=SET_MACHINE_TYPE_BODY,
    )
    # [END howto_operator_gce_set_machine_type]

    # Duplicate set machine type for idempotence testing
    gce_set_machine_type2 = GceSetMachineTypeOperator(
        task_id='gcp_compute_set_machine_type2',
        project_id=PROJECT_ID,
        zone=LOCATION,
        resource_id=INSTANCE,
        body=SET_MACHINE_TYPE_BODY,
    )

    # Run the tasks strictly one after another: each action is issued twice
    # in a row (see the "idempotence testing" duplicates above).
    (gce_instance_start >> gce_instance_start2
     >> gce_instance_stop >> gce_instance_stop2
     >> gce_set_machine_type >> gce_set_machine_type2)

Просмотреть файл

@ -0,0 +1,167 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
from googleapiclient.discovery import build
from airflow import AirflowException
from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook
# Number of retries - passed as ``num_retries`` to every googleapiclient
# ``execute()`` call in this module, so that requests which are "retriable"
# are retried.
NUM_RETRIES = 5
# Time to sleep (in seconds) between consecutive checks of the operation
# status while waiting for an asynchronous operation to finish.
TIME_TO_SLEEP_IN_SECONDS = 1
class GceOperationStatus:
    """
    Values of the ``status`` field of a Compute Engine operation that this
    module knows about. Only ``DONE`` is compared against when polling.
    """
    PENDING = 'PENDING'
    RUNNING = 'RUNNING'
    DONE = 'DONE'
# noinspection PyAbstractClass
class GceHook(GoogleCloudBaseHook):
    """
    Hook for Google Compute Engine APIs.

    :param api_version: Version of the Compute Engine API to build the client for
        (e.g. v1).
    :type api_version: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param delegate_to: Passed through unchanged to the base hook.
    :type delegate_to: str
    """
    # Cached service object; created on the first call to get_conn().
    _conn = None

    def __init__(self,
                 api_version,
                 gcp_conn_id='google_cloud_default',
                 delegate_to=None):
        super(GceHook, self).__init__(gcp_conn_id, delegate_to)
        self.api_version = api_version

    def get_conn(self):
        """
        Retrieves connection to Google Compute Engine.

        The service object is built once and then reused on later calls.

        :return: Google Compute Engine services object, as returned by
            ``googleapiclient.discovery.build``
        :rtype: object
        """
        if not self._conn:
            http_authorized = self._authorize()
            self._conn = build('compute', self.api_version,
                               http=http_authorized, cache_discovery=False)
        return self._conn

    def start_instance(self, project_id, zone, resource_id):
        """
        Starts an existing instance defined by project_id, zone and resource_id.

        Blocks until the operation returned by the API reports completion.

        :param project_id: Google Cloud Platform project where the Compute Engine
            instance exists.
        :type project_id: str
        :param zone: Google Cloud Platform zone where the instance exists.
        :type zone: str
        :param resource_id: Name of the Compute Engine instance resource.
        :type resource_id: str
        :return: True if the operation succeeded, raises an error otherwise
        :rtype: bool
        """
        response = self.get_conn().instances().start(
            project=project_id,
            zone=zone,
            instance=resource_id
        ).execute(num_retries=NUM_RETRIES)
        operation_name = response["name"]
        return self._wait_for_operation_to_complete(project_id, zone, operation_name)

    def stop_instance(self, project_id, zone, resource_id):
        """
        Stops an instance defined by project_id, zone and resource_id.

        Blocks until the operation returned by the API reports completion.

        :param project_id: Google Cloud Platform project where the Compute Engine
            instance exists.
        :type project_id: str
        :param zone: Google Cloud Platform zone where the instance exists.
        :type zone: str
        :param resource_id: Name of the Compute Engine instance resource.
        :type resource_id: str
        :return: True if the operation succeeded, raises an error otherwise
        :rtype: bool
        """
        response = self.get_conn().instances().stop(
            project=project_id,
            zone=zone,
            instance=resource_id
        ).execute(num_retries=NUM_RETRIES)
        operation_name = response["name"]
        return self._wait_for_operation_to_complete(project_id, zone, operation_name)

    def set_machine_type(self, project_id, zone, resource_id, body):
        """
        Sets machine type of an instance defined by project_id, zone and resource_id.

        Blocks until the operation returned by the API reports completion.

        :param project_id: Google Cloud Platform project where the Compute Engine
            instance exists.
        :type project_id: str
        :param zone: Google Cloud Platform zone where the instance exists.
        :type zone: str
        :param resource_id: Name of the Compute Engine instance resource.
        :type resource_id: str
        :param body: Body required by the Compute Engine setMachineType API,
            as described in
            https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType
        :type body: dict
        :return: True if the operation succeeded, raises an error otherwise
        :rtype: bool
        """
        response = self._execute_set_machine_type(project_id, zone, resource_id, body)
        operation_name = response["name"]
        return self._wait_for_operation_to_complete(project_id, zone, operation_name)

    def _execute_set_machine_type(self, project_id, zone, resource_id, body):
        """
        Issues the setMachineType request and returns the raw API response.
        """
        return self.get_conn().instances().setMachineType(
            project=project_id, zone=zone, instance=resource_id, body=body)\
            .execute(num_retries=NUM_RETRIES)

    def _wait_for_operation_to_complete(self, project_id, zone, operation_name):
        """
        Waits for the named operation to complete - checks status of the
        asynchronous call.

        The status is polled every TIME_TO_SLEEP_IN_SECONDS seconds and there is
        no upper bound on the waiting time.

        :param project_id: Google Cloud Platform project the operation belongs to.
        :type project_id: str
        :param zone: Google Cloud Platform zone the operation belongs to.
        :type zone: str
        :param operation_name: name of the operation
        :type operation_name: str
        :return: True if the operation succeeded, raises an error otherwise
        :rtype: bool
        :raises AirflowException: when the finished operation carries a non-empty
            "error" field.
        """
        service = self.get_conn()
        while True:
            operation_response = self._check_operation_status(
                service, operation_name, project_id, zone)
            if operation_response.get("status") == GceOperationStatus.DONE:
                error = operation_response.get("error")
                if error:
                    code = operation_response.get("httpErrorStatusCode")
                    msg = operation_response.get("httpErrorMessage")
                    raise AirflowException(
                        "{} {}: ".format(code, msg) + self._format_operation_error(error))
                # No meaningful info to return from the response in case of success
                return True
            time.sleep(TIME_TO_SLEEP_IN_SECONDS)

    @staticmethod
    def _format_operation_error(error):
        """
        Turns the "error" field of an operation response into a message string.
        """
        errors = error.get("errors") if isinstance(error, dict) else None
        if isinstance(errors, list):
            # Extracting the errors list as string and trimming square braces
            return str(errors)[1:-1]
        # Without an "errors" list, slicing str(None) would leave just "on";
        # report the whole error payload instead.
        return str(error)

    def _check_operation_status(self, service, operation_name, project_id, zone):
        """
        Fetches the current state of a zone operation and returns the raw response.
        """
        return service.zoneOperations().get(
            project=project_id, zone=zone, operation=operation_name).execute(
            num_retries=NUM_RETRIES)

Просмотреть файл

@ -0,0 +1,183 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from airflow import AirflowException
from airflow.contrib.hooks.gcp_compute_hook import GceHook
from airflow.contrib.utils.gcp_field_validator import GcpBodyFieldValidator
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
class GceBaseOperator(BaseOperator):
    """
    Abstract base operator for Google Compute Engine operators to inherit from.

    Stores the instance coordinates, checks that the required ones are not
    empty and creates the GceHook used by the concrete operators.

    :param project_id: Google Cloud Platform project where the Compute Engine
        instance exists.
    :type project_id: str
    :param zone: Google Cloud Platform zone where the instance exists.
    :type zone: str
    :param resource_id: Name of the Compute Engine instance resource.
    :type resource_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param api_version: API version used (e.g. v1).
    :type api_version: str
    """

    @apply_defaults
    def __init__(self,
                 project_id,
                 zone,
                 resource_id,
                 gcp_conn_id='google_cloud_default',
                 api_version='v1',
                 *args, **kwargs):
        self.project_id = project_id
        self.zone = zone
        self.resource_id = resource_id
        self.gcp_conn_id = gcp_conn_id
        self.api_version = api_version
        self.full_location = 'projects/{}/zones/{}'.format(project_id, zone)
        # Fail before the hook is created when a required value is empty.
        self._validate_inputs()
        # NOTE(review): the hook is built from the constructor values, so a value
        # rendered later for 'gcp_conn_id' / 'api_version' (both are listed in the
        # subclasses' template_fields) would not reach it - confirm this is intended.
        self._hook = GceHook(gcp_conn_id=gcp_conn_id, api_version=api_version)
        super(GceBaseOperator, self).__init__(*args, **kwargs)

    def _validate_inputs(self):
        """
        Raises AirflowException for the first of project_id, zone and resource_id
        (checked in that order) that is empty.
        """
        for param_name in ('project_id', 'zone', 'resource_id'):
            if not getattr(self, param_name):
                raise AirflowException(
                    "The required parameter '{}' is missing".format(param_name))

    def execute(self, context):
        """
        Does nothing here; concrete operators override this method.
        """
        return None
class GceInstanceStartOperator(GceBaseOperator):
    """
    Start an instance in Google Compute Engine.

    :param project_id: Google Cloud Platform project where the Compute Engine
        instance exists.
    :type project_id: str
    :param zone: Google Cloud Platform zone where the instance exists.
    :type zone: str
    :param resource_id: Name of the Compute Engine instance resource.
    :type resource_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param api_version: API version used (e.g. v1).
    :type api_version: str
    """
    template_fields = (
        'project_id',
        'zone',
        'resource_id',
        'gcp_conn_id',
        'api_version',
    )

    @apply_defaults
    def __init__(self,
                 project_id,
                 zone,
                 resource_id,
                 gcp_conn_id='google_cloud_default',
                 api_version='v1',
                 *args, **kwargs):
        # All the set-up (validation, hook creation) lives in the base class.
        super(GceInstanceStartOperator, self).__init__(
            project_id=project_id,
            zone=zone,
            resource_id=resource_id,
            gcp_conn_id=gcp_conn_id,
            api_version=api_version,
            *args, **kwargs)

    def execute(self, context):
        """
        Starts the instance through the hook and returns the hook's result.
        """
        hook = self._hook
        return hook.start_instance(self.project_id, self.zone, self.resource_id)
class GceInstanceStopOperator(GceBaseOperator):
    """
    Stop an instance in Google Compute Engine.

    :param project_id: Google Cloud Platform project where the Compute Engine
        instance exists.
    :type project_id: str
    :param zone: Google Cloud Platform zone where the instance exists.
    :type zone: str
    :param resource_id: Name of the Compute Engine instance resource.
    :type resource_id: str
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param api_version: API version used (e.g. v1).
    :type api_version: str
    """
    template_fields = (
        'project_id',
        'zone',
        'resource_id',
        'gcp_conn_id',
        'api_version',
    )

    @apply_defaults
    def __init__(self,
                 project_id,
                 zone,
                 resource_id,
                 gcp_conn_id='google_cloud_default',
                 api_version='v1',
                 *args, **kwargs):
        # All the set-up (validation, hook creation) lives in the base class.
        super(GceInstanceStopOperator, self).__init__(
            project_id=project_id,
            zone=zone,
            resource_id=resource_id,
            gcp_conn_id=gcp_conn_id,
            api_version=api_version,
            *args, **kwargs)

    def execute(self, context):
        """
        Stops the instance through the hook and returns the hook's result.
        """
        hook = self._hook
        return hook.stop_instance(self.project_id, self.zone, self.resource_id)
# Validation specification for the setMachineType request body: the
# "machineType" field has to match the regexp below (at least one character).
SET_MACHINE_TYPE_VALIDATION_SPECIFICATION = [
    {"name": "machineType", "regexp": "^.+$"},
]
class GceSetMachineTypeOperator(GceBaseOperator):
    """
    Changes the machine type for a stopped instance to the machine type specified in
    the request.

    :param project_id: Google Cloud Platform project where the Compute Engine
        instance exists.
    :type project_id: str
    :param zone: Google Cloud Platform zone where the instance exists.
    :type zone: str
    :param resource_id: Name of the Compute Engine instance resource.
    :type resource_id: str
    :param body: Body required by the Compute Engine setMachineType API, as described in
        https://cloud.google.com/compute/docs/reference/rest/v1/instances/setMachineType#request-body
    :type body: dict
    :param gcp_conn_id: The connection ID used to connect to Google Cloud Platform.
    :type gcp_conn_id: str
    :param api_version: API version used (e.g. v1).
    :type api_version: str
    :param validate_body: When False, no validator is created and the body is
        passed to the hook without client-side validation.
    :type validate_body: bool
    """
    template_fields = (
        'project_id',
        'zone',
        'resource_id',
        'gcp_conn_id',
        'api_version',
    )

    @apply_defaults
    def __init__(self,
                 project_id,
                 zone,
                 resource_id,
                 body,
                 gcp_conn_id='google_cloud_default',
                 api_version='v1',
                 validate_body=True,
                 *args, **kwargs):
        self.body = body
        # The validator exists only when validation was requested.
        self._field_validator = (
            GcpBodyFieldValidator(SET_MACHINE_TYPE_VALIDATION_SPECIFICATION,
                                  api_version=api_version)
            if validate_body else None
        )
        super(GceSetMachineTypeOperator, self).__init__(
            project_id=project_id,
            zone=zone,
            resource_id=resource_id,
            gcp_conn_id=gcp_conn_id,
            api_version=api_version,
            *args, **kwargs)

    def _validate_all_body_fields(self):
        """
        Validates the body when a validator was created; otherwise does nothing.
        """
        if not self._field_validator:
            return
        self._field_validator.validate(self.body)

    def execute(self, context):
        """
        Validates the body (if enabled), then sets the machine type through the
        hook and returns the hook's result.
        """
        self._validate_all_body_fields()
        hook = self._hook
        return hook.set_machine_type(self.project_id, self.zone,
                                     self.resource_id, self.body)

Просмотреть файл

@ -20,277 +20,23 @@ import re
from googleapiclient.errors import HttpError
from airflow import AirflowException, LoggingMixin
from airflow import AirflowException
from airflow.contrib.utils.gcp_field_validator import GcpBodyFieldValidator, \
GcpFieldValidationException
from airflow.version import version
from airflow.models import BaseOperator
from airflow.contrib.hooks.gcp_function_hook import GcfHook
from airflow.utils.decorators import apply_defaults
# TODO: This whole section should be extracted later to contrib/tools/field_validator.py
# Field types that are allowed to declare nested fields ('fields') in a
# validation specification entry.
COMPOSITE_FIELD_TYPES = ['union', 'dict']
class FieldValidationException(AirflowException):
    """
    Raised when a field of the validated dictionary does not conform to the
    validation specification.

    :param message: description of the validation problem
    :type message: str
    """

    def __init__(self, message):
        super(FieldValidationException, self).__init__(message)
class ValidationSpecificationException(AirflowException):
    """
    Raised when the validation specification itself is wrong (as opposed to the
    dictionary being validated).

    This should only happen during development, as ideally the specification
    itself should not be invalid.

    :param message: description of the problem with the specification
    :type message: str
    """

    def __init__(self, message):
        super(ValidationSpecificationException, self).__init__(message)
# TODO: make better description, add some examples
# TODO: move to contrib/utils folder when we reuse it.
class BodyFieldValidator(LoggingMixin):
    """
    Validates correctness of request body according to specification.

    The specification can describe various type of fields including custom
    validation, and union of fields. This validator is meant to be reusable by
    various operators in the near future, but for now it is left as part of the
    Google Cloud Function, so documentation about the validator is not yet
    complete. To see what kind of specification can be used, please take a look at
    gcp_function_operator.CLOUD_FUNCTION_VALIDATION which specifies validation
    for GCF deploy operator.

    :param validation_specs: list of dictionaries describing validation specification
    :type validation_specs: [dict]
    :param api_version: Version of the api used (for example v1)
    :type api_version: str
    """

    def __init__(self, validation_specs, api_version):
        # type: ([dict], str) -> None
        super(BodyFieldValidator, self).__init__()
        self._validation_specs = validation_specs
        self._api_version = api_version

    @staticmethod
    def _get_field_name_with_parent(field_name, parent):
        """
        Returns the dotted path of the field: 'parent.field_name', or just
        'field_name' when there is no parent.
        """
        if parent:
            return parent + '.' + field_name
        return field_name

    @staticmethod
    def _sanity_checks(children_validation_specs, field_type, full_field_path,
                       regexp, custom_validation, value):
        # type: (dict, str, str, str, function, object) -> None
        """
        Checks that a required field is present and that the specification entry
        does not combine mutually exclusive settings.

        :raises FieldValidationException: when a non-union field has no value
        :raises ValidationSpecificationException: when the entry mixes 'type' with
            'regexp' or 'custom_validation', or declares nested fields for a
            non-composite type
        """
        if value is None and field_type != 'union':
            raise FieldValidationException(
                "The required body field '{}' is missing. Please add it.".
                format(full_field_path))
        if regexp and field_type:
            raise ValidationSpecificationException(
                "The validation specification entry '{}' has both type and regexp. "
                "The regexp is only allowed without type (i.e. assume type is 'str' "
                "that can be validated with regexp)".format(full_field_path))
        if children_validation_specs and field_type not in COMPOSITE_FIELD_TYPES:
            raise ValidationSpecificationException(
                "Nested fields are specified in field '{}' of type '{}'. "
                "Nested fields are only allowed for fields of those types: ('{}').".
                format(full_field_path, field_type, COMPOSITE_FIELD_TYPES))
        if custom_validation and field_type:
            raise ValidationSpecificationException(
                "The validation specification field '{}' has both type and "
                "custom_validation. Custom validation is only allowed without type.".
                format(full_field_path))

    @staticmethod
    def _validate_regexp(full_field_path, regexp, value):
        # type: (str, str, str) -> None
        """
        Raises FieldValidationException when the value does not match the regexp.
        """
        if not re.match(regexp, value):
            # Note matching of only the beginning as we assume the regexps all-or-nothing
            raise FieldValidationException(
                "The body field '{}' of value '{}' does not match the field "
                "specification regexp: '{}'.".
                format(full_field_path, value, regexp))

    def _validate_dict(self, children_validation_specs, full_field_path, value):
        # type: (dict, str, dict) -> None
        """
        Validates every specified child field of a dict field and logs a warning
        for each key of the dict that the specification does not mention.
        """
        for child_validation_spec in children_validation_specs:
            self._validate_field(validation_spec=child_validation_spec,
                                 dictionary_to_validate=value,
                                 parent=full_field_path)
        # Collected once (after the children were validated) instead of being
        # rebuilt for every key of the dictionary.
        known_field_names = {spec['name'] for spec in children_validation_specs}
        for field_name in value:
            if field_name not in known_field_names:
                self.log.warning(
                    "The field '{}' is in the body, but is not specified in the "
                    "validation specification '{}'. "
                    "This might be because you are using newer API version and "
                    "new field names defined for that version. Then the warning "
                    "can be safely ignored, or you might want to upgrade the operator "
                    "to the version that supports the new API version.".format(
                        self._get_field_name_with_parent(field_name, full_field_path),
                        children_validation_specs))

    def _validate_union(self, children_validation_specs, full_field_path,
                        dictionary_to_validate):
        # type: (dict, str, dict) -> None
        """
        Validates a union: at most one of the variants may be present.

        :raises FieldValidationException: when two variants are present at once
        """
        field_found = False
        found_field_name = None
        for child_validation_spec in children_validation_specs:
            # Forcing optional so that we do not have to type optional = True
            # in specification for all union fields
            new_field_found = self._validate_field(
                validation_spec=child_validation_spec,
                dictionary_to_validate=dictionary_to_validate,
                parent=full_field_path,
                force_optional=True)
            field_name = child_validation_spec['name']
            if new_field_found and field_found:
                raise FieldValidationException(
                    "The mutually exclusive fields '{}' and '{}' belonging to the "
                    "union '{}' are both present. Please remove one".
                    format(field_name, found_field_name, full_field_path))
            if new_field_found:
                field_found = True
                found_field_name = field_name
        if not field_found:
            # Deliberately only a warning: no variant being present is not an error.
            self.log.warning(
                "There is no '{}' union defined in the body {}. "
                "Validation expected one of '{}' but could not find any. It's possible "
                "that you are using newer API version and there is another union variant "
                "defined for that version. Then the warning can be safely ignored, "
                "or you might want to upgrade the operator to the version that "
                "supports the new API version.".format(
                    full_field_path,
                    dictionary_to_validate,
                    [field['name'] for field in children_validation_specs]))

    def _validate_field(self, validation_spec, dictionary_to_validate, parent=None,
                        force_optional=False):
        """
        Validates if field is OK.

        :param validation_spec: specification of the field
        :type validation_spec: dict
        :param dictionary_to_validate: dictionary where the field should be present
        :type dictionary_to_validate: dict
        :param parent: full path of parent field
        :type parent: str
        :param force_optional: forces the field to be optional
            (all union fields have force_optional set to True)
        :type force_optional: bool
        :return: True if the field is present
        """
        field_name = validation_spec['name']
        field_type = validation_spec.get('type')
        optional = validation_spec.get('optional')
        regexp = validation_spec.get('regexp')
        children_validation_specs = validation_spec.get('fields')
        required_api_version = validation_spec.get('api_version')
        custom_validation = validation_spec.get('custom_validation')

        full_field_path = self._get_field_name_with_parent(field_name=field_name,
                                                           parent=parent)
        if required_api_version and required_api_version != self._api_version:
            self.log.debug(
                "Skipping validation of the field '{}' for API version '{}' "
                "as it is only valid for API version '{}'".
                format(field_name, self._api_version, required_api_version))
            return False
        value = dictionary_to_validate.get(field_name)

        if (optional or force_optional) and value is None:
            self.log.debug("The optional field '{}' is missing. That's perfectly OK.".
                           format(full_field_path))
            return False

        # Certainly down from here the field is present (value is not None)
        # so we should only return True from now on

        self._sanity_checks(children_validation_specs=children_validation_specs,
                            field_type=field_type,
                            full_field_path=full_field_path,
                            regexp=regexp,
                            custom_validation=custom_validation,
                            value=value)

        if regexp:
            self._validate_regexp(full_field_path, regexp, value)
        elif field_type == 'dict':
            if not isinstance(value, dict):
                raise FieldValidationException(
                    "The field '{}' should be dictionary type according to "
                    "specification '{}' but it is '{}'".
                    format(full_field_path, validation_spec, value))
            if children_validation_specs is None:
                self.log.debug(
                    "The dict field '{}' has no nested fields defined in the "
                    "specification '{}'. That's perfectly ok - it's content will "
                    "not be validated."
                    .format(full_field_path, validation_spec))
            else:
                self._validate_dict(children_validation_specs, full_field_path, value)
        elif field_type == 'union':
            if not children_validation_specs:
                raise ValidationSpecificationException(
                    "The union field '{}' has no nested fields "
                    "defined in specification '{}'. Unions should have at least one "
                    "nested field defined.".format(full_field_path, validation_spec))
            # The variants of a union live in the same dictionary as the union
            # entry itself, hence the enclosing dictionary is passed on.
            self._validate_union(children_validation_specs, full_field_path,
                                 dictionary_to_validate)
        elif custom_validation:
            try:
                custom_validation(value)
            except Exception as e:
                # Any failure of the custom validator is reported as a
                # validation error of the field.
                raise FieldValidationException(
                    "Error while validating custom field '{}' specified by '{}': '{}'".
                    format(full_field_path, validation_spec, e))
        elif field_type is None:
            self.log.debug("The type of field '{}' is not specified in '{}'. "
                           "Not validating its content.".
                           format(full_field_path, validation_spec))
        else:
            raise ValidationSpecificationException(
                "The field '{}' is of type '{}' in specification '{}'. "
                "This type is unknown to validation!".format(
                    full_field_path, field_type, validation_spec))
        return True

    def validate(self, body_to_validate):
        """
        Validates if the body (dictionary) follows specification that the validator was
        instantiated with. Raises ValidationSpecificationException or
        FieldValidationException in case of problems with specification or the
        body not conforming to the specification respectively.

        :param body_to_validate: body that must follow the specification
        :type body_to_validate: dict
        :return: None
        """
        try:
            for validation_spec in self._validation_specs:
                self._validate_field(validation_spec=validation_spec,
                                     dictionary_to_validate=body_to_validate)
        except FieldValidationException as e:
            # Re-raised with the whole body included in the message.
            raise FieldValidationException(
                "There was an error when validating: field '{}': '{}'".
                format(body_to_validate, e))
# TODO End of field validator to be extracted
def _validate_available_memory_in_mb(value):
if int(value) <= 0:
raise FieldValidationException("The available memory has to be greater than 0")
raise GcpFieldValidationException("The available memory has to be greater than 0")
def _validate_max_instances(value):
if int(value) <= 0:
raise FieldValidationException(
raise GcpFieldValidationException(
"The max instances parameter has to be greater than 0")
@ -378,9 +124,10 @@ class GcfFunctionDeployOperator(BaseOperator):
self.api_version = api_version
self.zip_path = zip_path
self.zip_path_preprocessor = ZipPathPreprocessor(body, zip_path)
self.validate_body = validate_body
self._field_validator = BodyFieldValidator(CLOUD_FUNCTION_VALIDATION,
api_version=api_version)
self._field_validator = None
if validate_body:
self._field_validator = GcpBodyFieldValidator(CLOUD_FUNCTION_VALIDATION,
api_version=api_version)
self._hook = GcfHook(gcp_conn_id=self.gcp_conn_id, api_version=self.api_version)
self._validate_inputs()
super(GcfFunctionDeployOperator, self).__init__(*args, **kwargs)
@ -395,7 +142,8 @@ class GcfFunctionDeployOperator(BaseOperator):
self.zip_path_preprocessor.preprocess_body()
def _validate_all_body_fields(self):
self._field_validator.validate(self.body)
if self._field_validator:
self._field_validator.validate(self.body)
def _create_new_function(self):
self._hook.create_new_function(self.full_location, self.body)
@ -406,8 +154,8 @@ class GcfFunctionDeployOperator(BaseOperator):
def _check_if_function_exists(self):
name = self.body.get('name')
if not name:
raise FieldValidationException("The 'name' field should be present in "
"body: '{}'.".format(self.body))
raise GcpFieldValidationException("The 'name' field should be present in "
"body: '{}'.".format(self.body))
try:
self._hook.get_function(name)
except HttpError as e:
@ -430,8 +178,7 @@ class GcfFunctionDeployOperator(BaseOperator):
def execute(self, context):
if self.zip_path_preprocessor.should_upload_function():
self.body[SOURCE_UPLOAD_URL] = self._upload_source_code()
if self.validate_body:
self._validate_all_body_fields()
self._validate_all_body_fields()
self._set_airflow_version_label()
if not self._check_if_function_exists():
self._create_new_function()

Просмотреть файл

@ -0,0 +1,417 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Validator for body fields sent via GCP API.
The validator performs validation of the body (being dictionary of fields) that
is sent in the API request to Google Cloud (via googleclient API usually).
Context
-------
The specification mostly focuses on helping Airflow DAG developers in the development
phase. You can build your own GCP operator (such as GcfDeployOperator for example) which
can have built-in validation specification for the particular API. It's super helpful
when developer plays with different fields and their values at the initial phase of
DAG development. Most of the Google Cloud APIs perform their own validation on the
server side, but most of the requests are asynchronous and you need to wait for result
of the operation. This takes precious time and slows
down iteration over the API. BodyFieldValidator is meant to be used on the client side
and it should therefore provide an instant feedback to the developer on misspelled or
wrong type of parameters.
The validation should be performed in "execute()" method call in order to allow
template parameters to be expanded before validation is performed.
Types of fields
---------------
Specification is an array of dictionaries - each dictionary describes field, its type,
validation, optionality, api_version supported and nested fields (for unions and dicts).
Typically (for clarity and in order to aid syntax highlighting) the array of
dicts should be defined as series of dict() executions. Fragment of example
specification might look as follows:
```
SPECIFICATION = [
    dict(name="an_union", type="union", optional=True, fields=[
        dict(name="variant_1", type="dict"),
        dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'),
    ]),
    dict(name="a_dict", type="dict", fields=[
        dict(name="field_1", type="dict"),
        dict(name="field_2", regexp=r'^.+$'),
    ]),
    ...
]
```
Each field should have key = "name" indicating field name. The field can be of one of the
following types:
* Dict fields: (key = "type", value="dict"):
Field of this type should contain nested fields in form of an array of dicts.
Each of the fields in the array is then expected (unless marked as optional)
and validated recursively. If an extra field is present in the dictionary, warning is
printed in log file (but the validation succeeds - see the Forward-compatibility notes)
* Union fields (key = "type", value="union"): field of this type should contain nested
fields in form of an array of dicts. One of the fields (and only one) should be
present (unless the union is marked as optional). If more than one union field is
present, FieldValidationException is raised. If none of the union fields is
present - warning is printed in the log (see below Forward-compatibility notes).
* Regexp-validated fields: (key = "regexp") - fields of this type are assumed to be
strings and they are validated with the regexp specified. Remember that the regexps
should ideally contain ^ at the beginning and $ at the end to make sure that
the whole field content is validated. Typically such regexp
validations should be used carefully and sparingly (see Forward-compatibility
notes below). Most of regexp validation should be at most r'^.+$'.
* Custom-validated fields: (key = "custom_validation") - fields of this type are validated
using method specified via custom_validation field. Any exception thrown in the custom
validation will be turned into FieldValidationException and will cause validation to
fail. Such custom validations might be used to check numeric fields (including
ranges of values), booleans or any other types of fields.
* API version: (key="api_version") if API version is specified, then the field will only
be validated when api_version used at field validator initialization matches exactly
the version specified. If you want to declare fields that are available in several
versions of the APIs, you should specify the field as many times as many API versions
should be supported (each time with different API version).
* If none of the keys ("type", "regexp", "custom_validation") is specified, the field is not validated.
You can see some of the field examples in EXAMPLE_VALIDATION_SPECIFICATION.
Forward-compatibility notes
---------------------------
Certain decisions are crucial to allow the client APIs to work also with future API
versions. Since body attached is passed to the APIs call, this is entirely
possible to pass-through any new fields in the body (for future API versions) -
albeit without validation on the client side - they can and will still be validated
on the server side usually.
Here are the guidelines that you should follow to make validation forward-compatible:
* most of the fields are not validated for their content. It's possible to use regexp
in some specific cases that are guaranteed not to change in the future, but for most
fields regexp validation should be r'^.+$' indicating check for non-emptiness
* api_version is not validated - user can pass any future version of the api here. The API
version is only used to filter parameters that are marked as present in this api version.
* Any new (not present in the specification) fields in the body are allowed (not verified).
For dictionaries, new fields can be added to dictionaries by future calls. However if an
unknown field in dictionary is added, a warning is logged by the client (but validation
remains successful). This is a very nice feature to protect against typos in names.
* For unions, newly added union variants can be added by future calls and they will
pass validation, however the content or presence of those fields will not be validated.
This means that it's possible to send a new non-validated union field together with an
old validated field and this problem will not be detected by the client. In such case
warning will be printed.
* When you add validator to an operator, you should also add ``validate_body`` parameter
(default = True) to __init__ of such operators - when it is set to False,
no validation should be performed. This is a safeguard for totally unpredicted and
backwards-incompatible changes that might sometimes occur in the APIs.
"""
import re
from airflow import LoggingMixin, AirflowException
# Field types that are allowed to declare nested fields (key "fields") in a
# validation specification entry.
COMPOSITE_FIELD_TYPES = ['union', 'dict']
class GcpFieldValidationException(AirflowException):
    """Thrown when validation finds dictionary field not valid according to specification.

    :param message: description of the validation problem
    :type message: str
    """
    def __init__(self, message):
        super(GcpFieldValidationException, self).__init__(message)
class GcpValidationSpecificationException(AirflowException):
    """Thrown when validation specification is wrong.

    This should only happen during development as ideally
    specification itself should not be invalid ;) .

    :param message: description of the problem found in the specification
    :type message: str
    """
    def __init__(self, message):
        super(GcpValidationSpecificationException, self).__init__(message)
def _int_greater_than_zero(value):
if int(value) <= 0:
raise GcpFieldValidationException("The available memory has to be greater than 0")
# Example specification showing every kind of entry the validator understands.
EXAMPLE_VALIDATION_SPECIFICATION = [
    # Required field checked against a regexp.
    dict(name="name", regexp="^.+$"),
    # Same check, but the field may be absent.
    dict(name="description", regexp="^.+$", optional=True),
    # Optional field checked by a custom validation function.
    dict(name="availableMemoryMb", custom_validation=_int_greater_than_zero,
         optional=True),
    # Optional dict without nested "fields": only its type is checked.
    dict(name="labels", optional=True, type="dict"),
    # Union: the variants are looked up directly in the enclosing dictionary
    # and at most one of them may be present.
    dict(name="an_union", type="union", fields=[
        dict(name="variant_1", regexp=r'^.+$'),
        # Only taken into account when the validator is created for 'v1beta2'.
        dict(name="variant_2", regexp=r'^.+$', api_version='v1beta2'),
        dict(name="variant_3", type="dict", fields=[
            dict(name="url", regexp=r'^.+$')
        ]),
        # No type, regexp or custom validation: content is not validated.
        dict(name="variant_4")
    ]),
]
class GcpBodyFieldValidator(LoggingMixin):
    """Validates correctness of request body according to specification.

    The specification can describe various type of fields including custom
    validation, and union of fields. This validator is to be reusable by
    various operators. See the EXAMPLE_VALIDATION_SPECIFICATION for some
    examples and explanations of how to create specification.

    :param validation_specs: list of dicts, each describing one field of the body
    :type validation_specs: [dict]
    :param api_version: Version of the api used (for example v1)
    :type api_version: str
    """
    def __init__(self, validation_specs, api_version):
        # type: ([dict], str) -> None
        super(GcpBodyFieldValidator, self).__init__()
        self._validation_specs = validation_specs
        self._api_version = api_version

    @staticmethod
    def _get_field_name_with_parent(field_name, parent):
        """Return 'parent.field_name', or just field_name when parent is empty."""
        if parent:
            return parent + '.' + field_name
        return field_name

    @staticmethod
    def _sanity_checks(children_validation_specs, field_type, full_field_path,
                       regexp, custom_validation, value):
        # type: ([dict], str, str, str, function, object) -> None
        """
        Checks that a required field is present and that its specification entry
        does not combine mutually exclusive keys.

        :raises GcpFieldValidationException: if a non-union field has no value
        :raises GcpValidationSpecificationException: if the specification entry
            mixes type with regexp or custom_validation, or declares nested
            fields for a non-composite type
        """
        # A union has no value of its own - its variants live in the parent dict.
        if value is None and field_type != 'union':
            raise GcpFieldValidationException(
                "The required body field '{}' is missing. Please add it.".
                format(full_field_path))
        if regexp and field_type:
            raise GcpValidationSpecificationException(
                "The validation specification entry '{}' has both type and regexp. "
                "The regexp is only allowed without type (i.e. assume type is 'str' "
                "that can be validated with regexp)".format(full_field_path))
        if children_validation_specs and field_type not in COMPOSITE_FIELD_TYPES:
            raise GcpValidationSpecificationException(
                "Nested fields are specified in field '{}' of type '{}'. "
                "Nested fields are only allowed for fields of those types: ('{}').".
                format(full_field_path, field_type, COMPOSITE_FIELD_TYPES))
        if custom_validation and field_type:
            raise GcpValidationSpecificationException(
                "The validation specification field '{}' has both type and "
                "custom_validation. Custom validation is only allowed without type.".
                format(full_field_path))

    @staticmethod
    def _validate_regexp(full_field_path, regexp, value):
        # type: (str, str, str) -> None
        """
        Checks the value against the regexp.

        :raises GcpFieldValidationException: if the value does not match
        """
        if not re.match(regexp, value):
            # Note matching of only the beginning as we assume the regexps all-or-nothing
            raise GcpFieldValidationException(
                "The body field '{}' of value '{}' does not match the field "
                "specification regexp: '{}'.".
                format(full_field_path, value, regexp))

    def _warn_about_unknown_field(self, field_path, validation_specs):
        # type: (str, [dict]) -> None
        """Logs a warning about a body field that the specification does not name."""
        self.log.warning(
            "The field '{}' is in the body, but is not specified in the "
            "validation specification '{}'. "
            "This might be because you are using newer API version and "
            "new field names defined for that version. Then the warning "
            "can be safely ignored, or you might want to upgrade the operator "
            "to the version that supports the new API version.".format(
                field_path, validation_specs))

    def _validate_dict(self, children_validation_specs, full_field_path, value):
        # type: ([dict], str, dict) -> None
        """
        Validates every nested field of a dict field and warns about keys of
        the dict that are not named in the specification.
        """
        for child_validation_spec in children_validation_specs:
            self._validate_field(validation_spec=child_validation_spec,
                                 dictionary_to_validate=value,
                                 parent=full_field_path)
        known_field_names = set(spec['name'] for spec in children_validation_specs)
        for field_name in value.keys():
            if field_name not in known_field_names:
                self._warn_about_unknown_field(
                    self._get_field_name_with_parent(field_name, full_field_path),
                    children_validation_specs)

    def _validate_union(self, children_validation_specs, full_field_path,
                        dictionary_to_validate):
        # type: ([dict], str, dict) -> None
        """
        Validates the variants of a union against the dictionary that contains them.

        :raises GcpFieldValidationException: if more than one variant is present
        """
        field_found = False
        found_field_name = None
        for child_validation_spec in children_validation_specs:
            # Forcing optional so that we do not have to type optional = True
            # in specification for all union fields
            new_field_found = self._validate_field(
                validation_spec=child_validation_spec,
                dictionary_to_validate=dictionary_to_validate,
                parent=full_field_path,
                force_optional=True)
            field_name = child_validation_spec['name']
            if new_field_found and field_found:
                raise GcpFieldValidationException(
                    "The mutually exclusive fields '{}' and '{}' belonging to the "
                    "union '{}' are both present. Please remove one".
                    format(field_name, found_field_name, full_field_path))
            if new_field_found:
                field_found = True
                found_field_name = field_name
        if not field_found:
            # Only a warning: the body may use a variant this specification
            # does not list.
            self.log.warning(
                "There is no '{}' union defined in the body {}. "
                "Validation expected one of '{}' but could not find any. It's possible "
                "that you are using newer API version and there is another union variant "
                "defined for that version. Then the warning can be safely ignored, "
                "or you might want to upgrade the operator to the version that "
                "supports the new API version.".format(
                    full_field_path,
                    dictionary_to_validate,
                    [field['name'] for field in children_validation_specs]))

    def _validate_field(self, validation_spec, dictionary_to_validate, parent=None,
                        force_optional=False):
        """
        Validates if field is OK.

        :param validation_spec: specification of the field
        :type validation_spec: dict
        :param dictionary_to_validate: dictionary where the field should be present
        :type dictionary_to_validate: dict
        :param parent: full path of parent field
        :type parent: str
        :param force_optional: forces the field to be optional
            (all union fields have force_optional set to True)
        :type force_optional: bool
        :return: True if the field is present
        """
        field_name = validation_spec['name']
        field_type = validation_spec.get('type')
        optional = validation_spec.get('optional')
        regexp = validation_spec.get('regexp')
        children_validation_specs = validation_spec.get('fields')
        required_api_version = validation_spec.get('api_version')
        custom_validation = validation_spec.get('custom_validation')

        full_field_path = self._get_field_name_with_parent(field_name=field_name,
                                                           parent=parent)
        if required_api_version and required_api_version != self._api_version:
            self.log.debug(
                "Skipping validation of the field '{}' for API version '{}' "
                "as it is only valid for API version '{}'".
                format(field_name, self._api_version, required_api_version))
            return False
        value = dictionary_to_validate.get(field_name)

        if (optional or force_optional) and value is None:
            self.log.debug("The optional field '{}' is missing. That's perfectly OK.".
                           format(full_field_path))
            return False

        # Certainly down from here the field is present (value is not None)
        # so we should only return True from now on

        self._sanity_checks(children_validation_specs=children_validation_specs,
                            field_type=field_type,
                            full_field_path=full_field_path,
                            regexp=regexp,
                            custom_validation=custom_validation,
                            value=value)

        if regexp:
            self._validate_regexp(full_field_path, regexp, value)
        elif field_type == 'dict':
            if not isinstance(value, dict):
                raise GcpFieldValidationException(
                    "The field '{}' should be dictionary type according to "
                    "specification '{}' but it is '{}'".
                    format(full_field_path, validation_spec, value))
            if children_validation_specs is None:
                self.log.debug(
                    "The dict field '{}' has no nested fields defined in the "
                    "specification '{}'. That's perfectly ok - it's content will "
                    "not be validated."
                    .format(full_field_path, validation_spec))
            else:
                self._validate_dict(children_validation_specs, full_field_path, value)
        elif field_type == 'union':
            if not children_validation_specs:
                raise GcpValidationSpecificationException(
                    "The union field '{}' has no nested fields "
                    "defined in specification '{}'. Unions should have at least one "
                    "nested field defined.".format(full_field_path, validation_spec))
            # Union variants are siblings of the union entry, so they are looked
            # up in the same dictionary rather than in a nested value.
            self._validate_union(children_validation_specs, full_field_path,
                                 dictionary_to_validate)
        elif custom_validation:
            try:
                custom_validation(value)
            except Exception as e:
                raise GcpFieldValidationException(
                    "Error while validating custom field '{}' specified by '{}': '{}'".
                    format(full_field_path, validation_spec, e))
        elif field_type is None:
            self.log.debug("The type of field '{}' is not specified in '{}'. "
                           "Not validating its content.".
                           format(full_field_path, validation_spec))
        else:
            raise GcpValidationSpecificationException(
                "The field '{}' is of type '{}' in specification '{}'. "
                "This type is unknown to validation!".format(
                    full_field_path, field_type, validation_spec))
        return True

    def validate(self, body_to_validate):
        """
        Validates if the body (dictionary) follows specification that the validator was
        instantiated with. Raises GcpValidationSpecificationException or
        GcpFieldValidationException in case of problems with specification or the
        body not conforming to the specification respectively.

        :param body_to_validate: body that must follow the specification
        :type body_to_validate: dict
        :return: None
        """
        try:
            for validation_spec in self._validation_specs:
                self._validate_field(validation_spec=validation_spec,
                                     dictionary_to_validate=body_to_validate)
        except GcpFieldValidationException as e:
            # Re-raise with the whole body attached to make the error easier to trace.
            raise GcpFieldValidationException(
                "There was an error when validating: body '{}': '{}'".
                format(body_to_validate, e))
        # Every name the specification declares counts as known, whatever API
        # version it is declared for. Union variants are top-level body keys,
        # so the names of the variants (not of the union itself) are collected.
        known_field_names = set()
        for spec in self._validation_specs:
            if spec.get('type') == 'union':
                known_field_names.update(
                    nested_union_spec['name'] for nested_union_spec in spec['fields']
                    if nested_union_spec.get('type') != 'union')
            else:
                known_field_names.add(spec['name'])
        for field_name in body_to_validate.keys():
            if field_name not in known_field_names:
                self._warn_about_unknown_field(field_name, self._validation_specs)

Просмотреть файл

@ -102,6 +102,62 @@ to execute a BigQuery load job.
:start-after: [START howto_operator_gcs_to_bq]
:end-before: [END howto_operator_gcs_to_bq]
GceInstanceStartOperator
^^^^^^^^^^^^^^^^^^^^^^^^
Use this operator to start an existing Google Compute Engine instance.
In this example parameter values are extracted from Airflow variables.
Moreover, the ``default_args`` dict is used to pass common arguments to all operators in a single DAG.
.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
:language: python
:start-after: [START howto_operator_gce_args]
:end-before: [END howto_operator_gce_args]
Define the :class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator` by passing the required arguments to the constructor.
.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gce_start]
:end-before: [END howto_operator_gce_start]
GceInstanceStopOperator
^^^^^^^^^^^^^^^^^^^^^^^
Use this operator to stop an existing Google Compute Engine instance.
For parameter definition take a look at :class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator` above.
Define the :class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStopOperator` by passing the required arguments to the constructor.
.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gce_stop]
:end-before: [END howto_operator_gce_stop]
GceSetMachineTypeOperator
^^^^^^^^^^^^^^^^^^^^^^^^^
Use this operator to change the machine type of a stopped instance to the specified machine type.
For parameter definition take a look at :class:`~airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator` above.
Define the :class:`~airflow.contrib.operators.gcp_compute_operator.GceSetMachineTypeOperator` by passing the required arguments to the constructor.
.. literalinclude:: ../../airflow/contrib/example_dags/example_gcp_compute.py
:language: python
:dedent: 4
:start-after: [START howto_operator_gce_set_machine_type]
:end-before: [END howto_operator_gce_set_machine_type]
GcfFunctionDeleteOperator
^^^^^^^^^^^^^^^^^^^^^^^^^

Просмотреть файл

@ -457,6 +457,37 @@ BigQueryHook
.. autoclass:: airflow.contrib.hooks.bigquery_hook.BigQueryHook
:members:
Compute Engine
''''''''''''''
Compute Engine Operators
""""""""""""""""""""""""
- :ref:`GceInstanceStartOperator` : start an existing Google Compute Engine instance.
- :ref:`GceInstanceStopOperator` : stop an existing Google Compute Engine instance.
- :ref:`GceSetMachineTypeOperator` : change the machine type for a stopped instance.
.. _GceInstanceStartOperator:
GceInstanceStartOperator
^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: airflow.contrib.operators.gcp_compute_operator.GceInstanceStartOperator
.. _GceInstanceStopOperator:
GceInstanceStopOperator
^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: airflow.contrib.operators.gcp_compute_operator.GceInstanceStopOperator
.. _GceSetMachineTypeOperator:
GceSetMachineTypeOperator
^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: airflow.contrib.operators.gcp_compute_operator.GceSetMachineTypeOperator
Cloud Functions
'''''''''''''''

Просмотреть файл

@ -0,0 +1,377 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import ast
import unittest
from airflow import AirflowException, configuration
from airflow.contrib.operators.gcp_compute_operator import GceInstanceStartOperator, \
GceInstanceStopOperator, GceSetMachineTypeOperator
from airflow.models import TaskInstance, DAG
from airflow.utils import timezone
try:
# noinspection PyProtectedMember
from unittest import mock
except ImportError:
try:
import mock
except ImportError:
mock = None
# Placeholder values shared by the operator tests below.
PROJECT_ID = 'project-id'
LOCATION = 'zone'
RESOURCE_ID = 'resource-id'
SHORT_MACHINE_TYPE_NAME = 'n1-machine-type'
# Request body passed to GceSetMachineTypeOperator in the tests.
SET_MACHINE_TYPE_BODY = {
    'machineType': 'zones/{}/machineTypes/{}'.format(LOCATION, SHORT_MACHINE_TYPE_NAME)
}
# Execution date used when rendering templates.
DEFAULT_DATE = timezone.datetime(2017, 1, 1)
class GceInstanceStartTest(unittest.TestCase):
    """Unit tests for GceInstanceStartOperator, GceInstanceStopOperator and
    GceSetMachineTypeOperator. GceHook, or selected methods of it, is patched
    in every test.
    """

    def _check_missing_parameter(self, mock_hook, operator_class, expected_message,
                                 **operator_kwargs):
        # Shared body of the "missing parameter" tests: creating and executing
        # the operator must raise AirflowException carrying expected_message,
        # and the hook must not have been instantiated.
        with self.assertRaises(AirflowException) as cm:
            op = operator_class(task_id='id', **operator_kwargs)
            op.execute(None)
        err = cm.exception
        self.assertIn(expected_message, str(err))
        mock_hook.assert_not_called()

    def _check_templated_fields(self, operator_class, **extra_kwargs):
        # Setting all of the operator's input parameters as templated dag_ids
        # (could be anything else) just to test if the templating works for all fields
        dag_id = 'test_dag_id'
        configuration.load_test_config()
        args = {
            'start_date': DEFAULT_DATE
        }
        self.dag = DAG(dag_id, default_args=args)
        op = operator_class(
            project_id='{{ dag.dag_id }}',
            zone='{{ dag.dag_id }}',
            resource_id='{{ dag.dag_id }}',
            gcp_conn_id='{{ dag.dag_id }}',
            api_version='{{ dag.dag_id }}',
            task_id='id',
            dag=self.dag,
            **extra_kwargs
        )
        ti = TaskInstance(op, DEFAULT_DATE)
        ti.render_templates()
        for field in ('project_id', 'zone', 'resource_id', 'gcp_conn_id',
                      'api_version'):
            self.assertEqual(dag_id, getattr(op, field))

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_instance_start(self, mock_hook):
        mock_hook.return_value.start_instance.return_value = True
        op = GceInstanceStartOperator(
            project_id=PROJECT_ID,
            zone=LOCATION,
            resource_id=RESOURCE_ID,
            task_id='id'
        )
        result = op.execute(None)
        mock_hook.assert_called_once_with(api_version='v1',
                                          gcp_conn_id='google_cloud_default')
        mock_hook.return_value.start_instance.assert_called_once_with(
            PROJECT_ID, LOCATION, RESOURCE_ID
        )
        self.assertTrue(result)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_instance_start_with_templates(self, mock_hook):
        self._check_templated_fields(GceInstanceStartOperator)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_start_should_throw_ex_when_missing_project_id(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceInstanceStartOperator,
            "The required parameter 'project_id' is missing",
            project_id="", zone=LOCATION, resource_id=RESOURCE_ID)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_start_should_throw_ex_when_missing_zone(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceInstanceStartOperator,
            "The required parameter 'zone' is missing",
            project_id=PROJECT_ID, zone="", resource_id=RESOURCE_ID)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_start_should_throw_ex_when_missing_resource_id(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceInstanceStartOperator,
            "The required parameter 'resource_id' is missing",
            project_id=PROJECT_ID, zone=LOCATION, resource_id="")

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_instance_stop(self, mock_hook):
        mock_hook.return_value.stop_instance.return_value = True
        op = GceInstanceStopOperator(
            project_id=PROJECT_ID,
            zone=LOCATION,
            resource_id=RESOURCE_ID,
            task_id='id'
        )
        result = op.execute(None)
        mock_hook.assert_called_once_with(api_version='v1',
                                          gcp_conn_id='google_cloud_default')
        mock_hook.return_value.stop_instance.assert_called_once_with(
            PROJECT_ID, LOCATION, RESOURCE_ID
        )
        self.assertTrue(result)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_instance_stop_with_templates(self, mock_hook):
        self._check_templated_fields(GceInstanceStopOperator)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_stop_should_throw_ex_when_missing_project_id(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceInstanceStopOperator,
            "The required parameter 'project_id' is missing",
            project_id="", zone=LOCATION, resource_id=RESOURCE_ID)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_stop_should_throw_ex_when_missing_zone(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceInstanceStopOperator,
            "The required parameter 'zone' is missing",
            project_id=PROJECT_ID, zone="", resource_id=RESOURCE_ID)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_stop_should_throw_ex_when_missing_resource_id(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceInstanceStopOperator,
            "The required parameter 'resource_id' is missing",
            project_id=PROJECT_ID, zone=LOCATION, resource_id="")

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_set_machine_type(self, mock_hook):
        mock_hook.return_value.set_machine_type.return_value = True
        op = GceSetMachineTypeOperator(
            project_id=PROJECT_ID,
            zone=LOCATION,
            resource_id=RESOURCE_ID,
            body=SET_MACHINE_TYPE_BODY,
            task_id='id'
        )
        result = op.execute(None)
        mock_hook.assert_called_once_with(api_version='v1',
                                          gcp_conn_id='google_cloud_default')
        mock_hook.return_value.set_machine_type.assert_called_once_with(
            PROJECT_ID, LOCATION, RESOURCE_ID, SET_MACHINE_TYPE_BODY
        )
        self.assertTrue(result)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_set_machine_type_with_templates(self, mock_hook):
        self._check_templated_fields(GceSetMachineTypeOperator, body={})

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_set_machine_type_should_throw_ex_when_missing_project_id(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceSetMachineTypeOperator,
            "The required parameter 'project_id' is missing",
            project_id="", zone=LOCATION, resource_id=RESOURCE_ID,
            body=SET_MACHINE_TYPE_BODY)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_set_machine_type_should_throw_ex_when_missing_zone(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceSetMachineTypeOperator,
            "The required parameter 'zone' is missing",
            project_id=PROJECT_ID, zone="", resource_id=RESOURCE_ID,
            body=SET_MACHINE_TYPE_BODY)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_set_machine_type_should_throw_ex_when_missing_resource_id(self, mock_hook):
        self._check_missing_parameter(
            mock_hook, GceSetMachineTypeOperator,
            "The required parameter 'resource_id' is missing",
            project_id=PROJECT_ID, zone=LOCATION, resource_id="",
            body=SET_MACHINE_TYPE_BODY)

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook')
    def test_set_machine_type_should_throw_ex_when_missing_machine_type(self, mock_hook):
        # Unlike the missing-parameter tests above, this one expects the hook
        # to have been created before the failure is reported.
        with self.assertRaises(AirflowException) as cm:
            op = GceSetMachineTypeOperator(
                project_id=PROJECT_ID,
                zone=LOCATION,
                resource_id=RESOURCE_ID,
                body={},
                task_id='id'
            )
            op.execute(None)
        err = cm.exception
        self.assertIn(
            "The required body field 'machineType' is missing. Please add it.", str(err))
        mock_hook.assert_called_once_with(api_version='v1',
                                          gcp_conn_id='google_cloud_default')

    # Textual form of an operation response that has status DONE but carries
    # an "error" entry; it is parsed with ast.literal_eval in the test below.
    MOCK_OP_RESPONSE = "{'kind': 'compute#operation', 'id': '8529919847974922736', " \
                       "'name': " \
                       "'operation-1538578207537-577542784f769-7999ab71-94f9ec1d', " \
                       "'zone': 'https://www.googleapis.com/compute/v1/projects/polidea" \
                       "-airflow/zones/europe-west3-b', 'operationType': " \
                       "'setMachineType', 'targetLink': " \
                       "'https://www.googleapis.com/compute/v1/projects/polidea-airflow" \
                       "/zones/europe-west3-b/instances/pa-1', 'targetId': " \
                       "'2480086944131075860', 'status': 'DONE', 'user': " \
                       "'uberdarek@polidea-airflow.iam.gserviceaccount.com', " \
                       "'progress': 100, 'insertTime': '2018-10-03T07:50:07.951-07:00', "\
                       "'startTime': '2018-10-03T07:50:08.324-07:00', 'endTime': " \
                       "'2018-10-03T07:50:08.484-07:00', 'error': {'errors': [{'code': " \
                       "'UNSUPPORTED_OPERATION', 'message': \"Machine type with name " \
                       "'machine-type-1' does not exist in zone 'europe-west3-b'.\"}]}, "\
                       "'httpErrorStatusCode': 400, 'httpErrorMessage': 'BAD REQUEST', " \
                       "'selfLink': " \
                       "'https://www.googleapis.com/compute/v1/projects/polidea-airflow" \
                       "/zones/europe-west3-b/operations/operation-1538578207537" \
                       "-577542784f769-7999ab71-94f9ec1d'} "

    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook'
                '._check_operation_status')
    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook'
                '._execute_set_machine_type')
    @mock.patch('airflow.contrib.operators.gcp_compute_operator.GceHook.get_conn')
    def test_set_machine_type_should_handle_and_trim_gce_error(
            self, get_conn, _execute_set_machine_type, _check_operation_status):
        get_conn.return_value = {}
        _execute_set_machine_type.return_value = {"name": "test-operation"}
        _check_operation_status.return_value = ast.literal_eval(self.MOCK_OP_RESPONSE)
        with self.assertRaises(AirflowException) as cm:
            op = GceSetMachineTypeOperator(
                project_id=PROJECT_ID,
                zone=LOCATION,
                resource_id=RESOURCE_ID,
                body=SET_MACHINE_TYPE_BODY,
                task_id='id'
            )
            op.execute(None)
        err = cm.exception
        _check_operation_status.assert_called_once_with(
            {}, "test-operation", PROJECT_ID, LOCATION)
        _execute_set_machine_type.assert_called_once_with(
            PROJECT_ID, LOCATION, RESOURCE_ID, SET_MACHINE_TYPE_BODY)
        # Checking the full message was sometimes failing due to different order
        # of keys in the serialized JSON
        self.assertIn("400 BAD REQUEST: {", str(err))  # checking the square bracket trim
        self.assertIn("UNSUPPORTED_OPERATION", str(err))

Просмотреть файл

@ -519,6 +519,23 @@ class GcfFunctionDeployTest(unittest.TestCase):
)
mock_hook.reset_mock()
@mock.patch('airflow.contrib.operators.gcp_function_operator.GcfHook')
def test_extra_parameter(self, mock_hook):
    # A body carrying a key on top of VALID_BODY must still be accepted:
    # execute() is expected to finish without raising.
    hook_instance = mock_hook.return_value
    hook_instance.list_functions.return_value = []
    hook_instance.create_new_function.return_value = True
    body_with_extra = deepcopy(VALID_BODY)
    body_with_extra['extra_parameter'] = 'extra'
    deploy_operator = GcfFunctionDeployOperator(
        project_id="test_project_id",
        location="test_region",
        body=body_with_extra,
        task_id="id"
    )
    deploy_operator.execute(None)
    mock_hook.assert_called_once_with(api_version='v1',
                                      gcp_conn_id='google_cloud_default')
    mock_hook.reset_mock()
class GcfFunctionDeleteTest(unittest.TestCase):
_FUNCTION_NAME = 'projects/project_name/locations/project_location/functions' \